# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

r"""Zero Redundancy Optimizer."""

import collections
import copy
import enum
import inspect
import io
import logging
from itertools import chain
from typing import Any, Callable, Optional, Union

import torch
import torch.distributed as dist
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from torch.distributed.optim.utils import functional_optim_map
from torch.optim import Optimizer


__all__ = ["ZeroRedundancyOptimizer"]


logger = logging.getLogger(__name__)

# Credits: classy_vision/generic/distributed_util.py
def _recursive_copy_to_device(
    value: Any,
    non_blocking: bool,
    device: torch.device,
) -> Any:
    r"""
    Recursively searches lists, tuples, dicts and copies tensors to device if possible.

    Non-tensor values are passed as-is in the result.

    .. note::
        These are all copies, so if there are two objects that reference
        the same object, then after this call, there will be two different objects
        referenced on the device.
    """
    if isinstance(value, torch.Tensor):
        return value.to(device, non_blocking=non_blocking)

    if isinstance(value, (list, tuple)):
        values = [
            _recursive_copy_to_device(val, non_blocking=non_blocking, device=device)
            for val in value
        ]
        return values if isinstance(value, list) else tuple(values)

    if isinstance(value, collections.abc.Mapping):
        return {
            key: _recursive_copy_to_device(
                val, non_blocking=non_blocking, device=device
            )
            for key, val in value.items()
        }

    return value


def _is_trainable(param: torch.Tensor) -> bool:
    r"""Return if a parameter is trainable, where trainability is equivalent to requiring a gradient."""
    return param.requires_grad

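# A minimal illustrative sketch (not part of this module's API): the helper
# name `_example_state_dict_to_cpu` is hypothetical and only shows the typical
# way `_recursive_copy_to_device()` is used, namely to move a nested optimizer
# ``state_dict`` onto the CPU, as `consolidate_state_dict()` does below.
def _example_state_dict_to_cpu(optim: Optimizer) -> Any:
    """Return a CPU copy of ``optim.state_dict()`` (illustrative only)."""
    return _recursive_copy_to_device(
        optim.state_dict(), non_blocking=True, device=torch.device("cpu")
    )
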
def _broadcast_object(
    obj: Any,
    src_rank: int,
    group: object = dist.group.WORLD,
    device: torch.device = torch.device("cpu"),
) -> Any:
    r"""
    Broadcast an object to the given group.

    It sends the object if called from the source rank and receives
    the object otherwise.

    Arguments:
        obj: object to broadcast; only used if called on the source rank.
        src_rank (int): source rank.
        group (``ProcessGroup``, optional): group used for the broadcast
            (default: ``dist.group.WORLD``).
        device (``torch.device``, optional): device to send from or receive
            to (default: ``torch.device("cpu")``).

    Returns:
        The broadcasted object.
    """
    if dist.get_rank() == src_rank:
        # Send the object
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        data = bytearray(buffer.getbuffer())
        length_tensor = torch.LongTensor([len(data)]).to(device)
        data_send_tensor = torch.ByteTensor(data).to(device)
        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
        dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
    else:
        # Receive the object
        length_tensor = torch.LongTensor([0]).to(device)
        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
        data_recv_tensor = torch.empty(
            [int(length_tensor.item())], dtype=torch.uint8, device=device
        )
        dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
        buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
        obj = torch.load(buffer, map_location=device, weights_only=False)
    return obj

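# Illustrative usage sketch (not part of this module's API): the function name
# `_example_broadcast_config` and the `config` payload are hypothetical; the
# sketch only shows the send/receive symmetry of `_broadcast_object()` once the
# default process group has been initialized.
def _example_broadcast_config(config: Optional[dict]) -> dict:
    """Broadcast a small picklable object from rank 0 to all ranks (illustrative only)."""
    src_rank = 0
    if dist.get_rank() == src_rank:
        # The source rank sends its local object ...
        return _broadcast_object(config, src_rank=src_rank)
    # ... and every other rank receives a copy (its own `config` is unused).
    return _broadcast_object(None, src_rank=src_rank)
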
class _ZeROJoinHook(JoinHook):
    def __init__(self, zero):
        assert isinstance(zero, ZeroRedundancyOptimizer), (
            "ZeRO join hook requires passing in a ZeroRedundancyOptimizer "
            "instance as the state"
        )
        self.zero = zero
        super().__init__()

    def main_hook(self):
        """
        Perform an optimizer step.

        This step updates the joined process's shard of
        the parameters and broadcasts those parameters.
        """
        self.zero.step()

class _DDPBucketAssignment:
    r"""
    Represent a :class:`DistributedDataParallel` bucket assignment.

    This means that a (possibly non-strict) subset of the parameters
    corresponding to a DDP bucket is assigned to a rank to update.

    Attributes:
        bucket_index (int): index of the bucket determined by the DDP gradient
            bucket all-reduce order.
        parameters (List[torch.Tensor]): model parameters in the bucket
            assigned to this rank.
        offset (int): offset into the :class:`GradBucket` 's :meth:`parameters`
            giving the index of the first element in the passed-in
            ``parameters``; this equivalently indexes into the
            :class:`GradBucket` 's :meth:`gradients`.
        device (torch.device): device on which the parameters are stored.
        tensor (torch.Tensor): flattened tensor giving the data of the
            parameter subset assigned to the rank.
    """

    def __init__(
        self,
        bucket_index: int,
        parameters: list[torch.Tensor],
        offset: int,
    ):
        self.bucket_index = bucket_index
        self.parameters = parameters
        self.offset = offset
        if len(self.parameters) == 0:
            raise ValueError("Empty bucket assignment")
        # DDP guarantees all parameters in the bucket have the same device
        self.device: torch.device = self.parameters[0].device
        self.tensor: Optional[torch.Tensor] = None

class _OverlapStatus(enum.IntEnum):
    r"""
    Define possible statuses that :class:`ZeroRedundancyOptimizer` can be in when overlapping with :class:`DistributedDataParallel`.

    Attributes:
        ``UNINITIALIZED``: The ZeRO instance is effectively uninitialized and
            is waiting for DDP to finalize its bucketing.
        ``DDP_HAS_REBUILT_BUCKETS``: DDP has rebuilt its buckets, meaning that
            its bucketing is finalized. The ZeRO instance can now collect the
            necessary information about the DDP bucketing.
        ``INITIALIZED``: The ZeRO instance is fully initialized and can now
            optimize parameters.
    """

    UNINITIALIZED = 0
    DDP_HAS_REBUILT_BUCKETS = 1
    INITIALIZED = 2

class _OverlapInfo:
    r"""
    Information needed by :class:`ZeroRedundancyOptimizer` to overlap with :class:`DistributedDataParallel`.

    Arguments:
        world_size (int): world size of the process group being used.

    Attributes:
        shard_buckets (bool): if ``True``, then the assignment of each
            :class:`DistributedDataParallel` bucket is partitioned across
            possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e.
            across possibly multiple ranks) to approximate uniformity following
            a threshold given by the total parameter size divided by the world
            size; if ``False``, then each bucket is wholly assigned to a single
            :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank);
            this should be set to the value passed into the hook constructor.
        status (_OverlapStatus): current status; see :class:`_OverlapStatus`
            for more information.
        params_per_bucket (List[List[torch.Tensor]]): ``params_per_bucket[i]``
            gives the model parameters in the ``i``th bucket.
        params_per_rank (List[List[torch.Tensor]]): ``params_per_rank[i]``
            gives the model parameters assigned to the ``i``th rank, where the
            parameters are grouped by increasing bucket indices.
        offsets (Dict[int, int]): maps from bucket index to the offset in
            ``self.params_per_rank[rank]`` giving the index of the first
            parameter in that bucket, where ``rank`` is this process's own
            rank; the keys of this :class:`dict` are the bucket indices
            assigned to this rank.
        num_bucket_assignments (int): total number of bucket assignments across
            all ranks; this is equal to the number of
            :class:`DistributedDataParallel` gradient buckets if
            ``shard_buckets=False`` and possibly greater otherwise.
        total_size (int, optional): total size of all buckets (i.e. sum of
            ``param.numel()`` for all ``param`` across all buckets) if
            ``shard_buckets=True``; otherwise, ``None``.
        broadcast_handles (List[Work]): :class:`list` of async work handles for
            the parameter broadcasts.
        bucket_index_to_future (Dict[int, torch.futures.Future]):
            :class:`dict` mapping bucket index to the corresponding all-reduce
            future.
        bucket_index_to_bucket (Dict[int, dist.GradBucket]): :class:`dict`
            mapping bucket index to the corresponding bucket.
        bucket_indices_seen (List[int]): :class:`list` of the bucket indices
            seen on this iteration.
    """

    def __init__(self, world_size) -> None:
        self.status: _OverlapStatus = _OverlapStatus.UNINITIALIZED
        self.shard_buckets: bool = False

        # Modified per bucket reconstruction
        self.params_per_bucket: list[list[torch.Tensor]] = []
        self.params_per_rank: list[list[torch.Tensor]] = [[] for _ in range(world_size)]
        self.offsets: dict[int, int] = {}
        # Group Ranks
        self.assigned_ranks_per_bucket: list[set[int]] = []
        self.num_bucket_assignments: int = 0
        self.total_size: Optional[int] = None

        # Modified per iteration
        self.broadcast_handles: list[Any] = []
        self.bucket_indices_seen: list[int] = []
        # Used by `hook_with_zero_step()`
        self.bucket_index_to_future: dict[int, torch.futures.Future] = {}
        self.bucket_index_to_bucket: dict[int, dist.GradBucket] = {}

    def wait_for_broadcasts(self) -> None:
        r"""
        Wait for all parameter broadcasts.

        This function should be called once all broadcasts have been scheduled,
        meaning ``self.broadcast_handles`` is filled. This clears ``self.broadcast_handles``
        in preparation for the next iteration.
        """
        assert len(self.broadcast_handles) == self.num_bucket_assignments, (
            f"Missing at least one broadcast handle on rank {dist.get_rank()}"
        )
        _ = [x.wait() for x in self.broadcast_handles]
        self.broadcast_handles.clear()

    def clear_per_iter_info(self) -> None:
        r"""
        Clear the data structures that are modified per-iteration.

        This function should be called at the end of an iteration.
        """
        self.bucket_indices_seen.clear()
        self.bucket_index_to_future.clear()
        self.bucket_index_to_bucket.clear()

class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
|
r"""
|
|
Wrap an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>` and shard its states across ranks in the group.

The sharding is done as described by ZeRO_.
|
|
|
|
The local optimizer instance in each rank is only
|
|
responsible for updating approximately ``1 / world_size`` parameters and
|
|
hence only needs to keep ``1 / world_size`` optimizer states. After
|
|
parameters are updated locally, each rank will broadcast its parameters to
|
|
all other peers to keep all model replicas in the same state.
|
|
``ZeroRedundancyOptimizer`` can be used in conjunction with
|
|
:class:`torch.nn.parallel.DistributedDataParallel` to reduce per-rank peak
|
|
memory consumption.
|
|
|
|
``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a number
of parameters at each rank. Each parameter belongs to a single rank and is
not divided among ranks. The partition is arbitrary and might not match
the parameter registration or usage order.
|
|
|
|
Arguments:
|
|
params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s
|
|
or :class:`dict` s giving all parameters, which will be sharded
|
|
across ranks.
|
|
|
|
Keyword Args:
|
|
optimizer_class (:class:`torch.optim.Optimizer`): the class of the local
optimizer.
|
|
process_group (``ProcessGroup``, optional): ``torch.distributed``
|
|
``ProcessGroup`` (default: ``dist.group.WORLD`` initialized by
|
|
:meth:`torch.distributed.init_process_group`).
|
|
parameters_as_bucket_view (bool, optional): if ``True``, parameters are
|
|
packed into buckets to speed up communication, and ``param.data``
|
|
fields point to bucket views at different offsets; if ``False``,
|
|
each individual parameter is communicated separately, and each
|
|
``params.data`` stays intact (default: ``False``).
|
|
overlap_with_ddp (bool, optional): if ``True``, :meth:`step` is
|
|
overlapped with :class:`DistributedDataParallel` 's gradient
|
|
synchronization; this requires (1) either a functional optimizer
|
|
for the ``optimizer_class`` argument or one with a functional
|
|
equivalent and (2) registering a DDP communication hook
|
|
constructed from one of the functions in ``ddp_zero_hook.py``;
|
|
parameters are packed into buckets matching those in
|
|
:class:`DistributedDataParallel`, meaning that the
|
|
``parameters_as_bucket_view`` argument is ignored.
|
|
If ``False``, :meth:`step` runs disjointly after the backward pass
|
|
(per normal).
|
|
(default: ``False``)
|
|
**defaults: any trailing arguments, which are forwarded to the local
|
|
optimizer.
|
|
|
|
Example::
|
|
|
|
>>> # xdoctest: +SKIP
|
|
>>> import torch.nn as nn
|
|
>>> from torch.distributed.optim import ZeroRedundancyOptimizer
|
|
>>> from torch.nn.parallel import DistributedDataParallel as DDP
|
|
>>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
|
|
>>> ddp = DDP(model, device_ids=[rank])
|
|
>>> opt = ZeroRedundancyOptimizer(
|
|
>>> ddp.parameters(),
|
|
>>> optimizer_class=torch.optim.Adam,
|
|
>>> lr=0.01
|
|
>>> )
|
|
>>> ddp(inputs).sum().backward()
|
|
>>> opt.step()
|
|
|
|
.. warning::
|
|
Currently, ``ZeroRedundancyOptimizer`` requires that all of the
|
|
passed-in parameters are the same dense type.
|
|
|
|
.. warning::
|
|
If you pass ``overlap_with_ddp=True``, be wary of the following: Given
|
|
the way that overlapping :class:`DistributedDataParallel` with
|
|
:class:`ZeroRedundancyOptimizer` is currently implemented, the first
|
|
two or three training iterations do not perform parameter updates in
|
|
the optimizer step, depending on whether ``static_graph=False`` or
``static_graph=True``, respectively. This is because ZeroRedundancyOptimizer needs
information about the gradient bucketing strategy used by
|
|
:class:`DistributedDataParallel`, which is not finalized until the
|
|
second forward pass if ``static_graph=False`` or until the third
|
|
forward pass if ``static_graph=True``. To adjust for this, one option
|
|
is to prepend dummy inputs.
|
|
|
|
.. warning:: ZeroRedundancyOptimizer is experimental and subject to change.
|
|
|
|
.. _ZeRO: https://arxiv.org/abs/1910.02054
|
|
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
params,
|
|
optimizer_class: type[Optimizer],
|
|
process_group: Optional[Any] = None,
|
|
parameters_as_bucket_view: bool = False,
|
|
overlap_with_ddp: bool = False,
|
|
**defaults: Any,
|
|
):
|
|
r"""Init."""
|
|
# Perform type and assumption checks on the input parameters
|
|
params = self._verify_and_init_params(params)
|
|
self._verify_same_dense_param_type()
|
|
|
|
# NOTE: The parent constructor uses `add_param_group()` which is
|
|
# partially overloaded in ZeroRedundancyOptimizer, so we use the
|
|
# `initialized` flag to dissociate the behaviour of `add_param_group()`
|
|
# between the parent and child.
|
|
self.initialized = False
|
|
|
|
Optimizer.__init__(self, params, defaults)
|
|
Joinable.__init__(self)
|
|
# Now, all parameters are held in both `self._all_params` and
|
|
# `self.param_groups`
|
|
|
|
# Internal data structures (`_cache` indicates lazily evaluated)
|
|
self._param_to_rank_cache: dict[torch.Tensor, int] = {}
|
|
self._param_to_index_cache: dict[torch.Tensor, int] = {}
|
|
self._partition_parameters_cache: list[list[dict]] = []
|
|
self._index_to_param_cache: list[torch.Tensor] = []
|
|
self._device_to_params_per_rank_cache: dict[
|
|
torch.device, list[list[torch.Tensor]]
|
|
] = {}
|
|
self._bucket_assignments_per_rank_cache: list[
|
|
dict[int, _DDPBucketAssignment]
|
|
] = []
|
|
self._is_trainable_mask = self._get_is_trainable_mask()
|
|
|
|
# Default device for collective communication and buckets
|
|
self._default_device = self._all_params[0].device
|
|
|
|
self.process_group = (
|
|
process_group if process_group is not None else dist.group.WORLD
|
|
)
|
|
self.world_size: int = dist.get_world_size(self.process_group)
|
|
self.rank: int = dist.get_rank(self.process_group)
|
|
self.global_rank: int = dist.distributed_c10d.get_global_rank(
|
|
self.process_group, self.rank
|
|
)
|
|
|
|
self._overlap_with_ddp: bool = overlap_with_ddp
|
|
self._optim_defaults = defaults
|
|
self._optim_constructor = self._get_optimizer_constructor(optimizer_class)
|
|
|
|
# If `overlap_with_ddp=True`, local optimizer initialization is delayed
|
|
# to run time after the necessary information has been collected
|
|
if not overlap_with_ddp:
|
|
self._init_local_optimizer()
|
|
else:
|
|
self._overlap_info: _OverlapInfo = _OverlapInfo(self.world_size)
|
|
if parameters_as_bucket_view:
|
|
logger.warning(
|
|
"`parameters_as_bucket_view=True` will be ignored since "
|
|
"`overlap_with_ddp=True`; instead, a different bucketing "
|
|
"strategy will be used"
|
|
)
|
|
|
|
# `self._buckets` is used if `parameters_as_bucket_view=True`, in
|
|
# which case parameter data is flattened into contiguous bucket tensors
|
|
self.parameters_as_bucket_view = parameters_as_bucket_view
|
|
self._buckets: list[list[torch.Tensor]] = []
|
|
self._build_param_buckets()
|
|
|
|
# Optional consolidated optimizer state, only populated if this rank
|
|
# is the target in `consolidate_state_dict()`
|
|
self._all_state_dicts: list[dict[str, Any]] = []
|
|
|
|
self.initialized = True
|
|
|
|
def _clear_cache(self) -> None:
|
|
r"""Clear the cached data structures giving partition information."""
|
|
self._partition_parameters_cache.clear()
|
|
self._param_to_rank_cache.clear()
|
|
self._index_to_param_cache.clear()
|
|
self._param_to_index_cache.clear()
|
|
self._device_to_params_per_rank_cache.clear()
|
|
self._bucket_assignments_per_rank_cache.clear()
|
|
|
|
def add_param_group(self, param_group: dict[str, Any]) -> None:
|
|
r"""
|
|
Add a parameter group to the :class:`Optimizer` 's ``param_groups``.
|
|
|
|
This can be useful when fine tuning a pre-trained network, as frozen
|
|
layers can be made trainable and added to the :class:`Optimizer` as
|
|
training progresses.
|
|
|
|
Arguments:
|
|
param_group (dict): specifies the parameters to be optimized and
|
|
group-specific optimization options.
|
|
|
|
.. warning:: This method handles updating the shards on all partitions
|
|
but needs to be called on all ranks. Calling this on a subset of
|
|
the ranks will cause the training to hang because communication
|
|
primitives are called depending on the managed parameters and
|
|
expect all the ranks to participate on the same set of parameters.
|
|
"""
|
|
if self.initialized and self._overlap_with_ddp:
|
|
raise RuntimeError(
|
|
"ZeroRedundancyOptimizer with `overlap_with_ddp=True` only "
|
|
"supports a single parameter group"
|
|
)
|
|
|
|
super().add_param_group(param_group)
|
|
# NOTE: The rest of the method assumes that the call to the parent's
|
|
# `add_param_group()` appends the new parameter group and preserves
|
|
# the previous parameter-group ordering
|
|
|
|
if self.initialized:
|
|
# Force a re-partitioning of the parameters
|
|
self._clear_cache()
|
|
param_groups = self._partition_parameters()[self.rank]
|
|
# NOTE: All parameters in the old parameter groups should be
|
|
# assigned to the same ranks so that the local optimizers do not
|
|
# need to be reinitialized
|
|
|
|
# Add the parameters assigned to this rank from the new parameter
|
|
# group to the local optimizer, if any
|
|
if len(param_groups) == len(self.optim.param_groups) + 1:
|
|
self.optim.add_param_group(param_groups[-1])
|
|
|
|
# Update the bucketing strategy accordingly
|
|
if self.parameters_as_bucket_view:
|
|
self._build_param_buckets()
|
|
|
|
def consolidate_state_dict(self, to: int = 0) -> None:
|
|
r"""
|
|
Consolidate a list of ``state_dict`` s (one per rank) on the target rank.
|
|
|
|
Arguments:
|
|
to (int): the rank that receives the optimizer states (default: 0).
|
|
|
|
Raises:
|
|
RuntimeError: if ``overlap_with_ddp=True`` and this method is
|
|
called before this :class:`ZeroRedundancyOptimizer` instance
|
|
has been fully initialized, which happens once
|
|
:class:`DistributedDataParallel` gradient buckets have been
|
|
rebuilt.
|
|
|
|
.. warning:: This needs to be called on all ranks.
|
|
"""
|
|
self._check_overlap_initialized()
|
|
|
|
# Sync the exposed `param_groups` attributes to the local optimizer in
|
|
# case they have been updated
|
|
self._sync_param_groups(self.param_groups, self.optim.param_groups)
|
|
|
|
# Pull the sharded state from all ranks and store them in rank order
|
|
empty_messenger = torch.tensor(
|
|
[0], dtype=torch.uint8, device=self._default_device
|
|
)
|
|
|
|
# NOTE: We wastefully use `broadcast()` (e.g. instead of `gather()`)
|
|
# due to compatibility issues with NCCL backend; a possible follow-up
|
|
# is to move all sharded state management to RPC RRef
|
|
self._all_state_dicts = []
|
|
for rank in range(self.world_size):
|
|
global_rank = dist.distributed_c10d.get_global_rank(
|
|
self.process_group, rank
|
|
)
|
|
if self.rank == to:
|
|
# Consolidate all local `state_dict`s on this rank, storing on
|
|
# CPU to save GPU memory
|
|
if rank == self.rank:
|
|
# Directly append own optimizer state
|
|
self._all_state_dicts.append(
|
|
_recursive_copy_to_device(
|
|
self.optim.state_dict(),
|
|
non_blocking=True,
|
|
device=torch.device("cpu"),
|
|
)
|
|
)
|
|
else:
|
|
# Receive the optimizer state from the source rank
|
|
local_state_dict = _broadcast_object(
|
|
empty_messenger,
|
|
src_rank=global_rank,
|
|
group=self.process_group,
|
|
device=self._default_device,
|
|
)
|
|
self._all_state_dicts.append(
|
|
_recursive_copy_to_device(
|
|
local_state_dict,
|
|
non_blocking=True,
|
|
device=torch.device("cpu"),
|
|
)
|
|
)
|
|
else:
|
|
if rank == self.rank:
|
|
# Send the optimizer state to the target rank
|
|
_ = _broadcast_object(
|
|
self.optim.state_dict(),
|
|
src_rank=self.global_rank,
|
|
group=self.process_group,
|
|
device=self._default_device,
|
|
)
|
|
elif rank != to:
|
|
# Discard the received object; `broadcast()` is used for
|
|
# compatibility reasons
|
|
_ = _broadcast_object(
|
|
empty_messenger,
|
|
src_rank=global_rank,
|
|
group=self.process_group,
|
|
device=self._default_device,
|
|
)
|
|
|
|
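# Illustrative checkpointing sketch (hedged; `PATH` is hypothetical): the
# consolidation above is a collective, so every rank must call it, and only
# the target rank can then materialize the full `state_dict()`:
#
#     optimizer.consolidate_state_dict(to=0)
#     if dist.get_rank() == 0:
#         torch.save(optimizer.state_dict(), PATH)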
def _verify_params_per_rank(
|
|
self,
|
|
params_per_rank: list[list[torch.Tensor]],
|
|
) -> None:
|
|
r"""
|
|
Verify ``params_per_rank`` for :meth:`_partition_parameters`.
|
|
|
|
The verification is done by checking that ``params_per_rank`` has length equal
|
|
to the world size and that it does not contain any parameters not passed into the
|
|
:class:`ZeroRedundancyOptimizer` constructor.
|
|
|
|
The parameters in ``params_per_rank`` being a strict subset of those
|
|
passed into the constructor is valid since some parameters may be
|
|
frozen.
|
|
|
|
Raises:
|
|
ValueError: if ``params_per_rank`` does not have length equal to
|
|
the world size or if it contains a parameter that was not
|
|
passed into the :class:`ZeroRedundancyOptimizer` constructor.
|
|
"""
|
|
if len(params_per_rank) != self.world_size:
|
|
raise ValueError(
|
|
"`params_per_rank` must have length equal to the world size"
|
|
)
|
|
all_params_set = set(self._all_params)
|
|
for params in params_per_rank:
|
|
for param in params:
|
|
if param not in all_params_set:
|
|
raise ValueError(
|
|
"Passing a new parameter in `params_per_rank` that "
|
|
"was not passed into the ZeroRedundancyOptimizer "
|
|
"constructor"
|
|
)
|
|
|
|
def _partition_param_group(
|
|
self, param_group: dict[str, Any], params_per_rank: list[list[torch.Tensor]]
|
|
) -> None:
|
|
r"""
|
|
Partition the parameter group ``param_group`` according to ``params_per_rank``.
|
|
|
|
The partition will modify the ``self._partition_parameters_cache``. This method should
|
|
only be used as a subroutine for :meth:`_partition_parameters`.
|
|
|
|
Arguments:
|
|
param_group (dict[str, Any]): a parameter group as normally defined
|
|
in an optimizer state.
|
|
params_per_rank (list[list[torch.Tensor]]): a :class:`list` of
|
|
length world size containing :class:`list` s of parameters to
|
|
assign to each rank.
|
|
"""
|
|
for rank, params in enumerate(params_per_rank):
|
|
rank_param_group = copy.copy(param_group)
|
|
rank_param_group["params"] = params
|
|
self._partition_parameters_cache[rank].append(rank_param_group)
|
|
|
|
def _partition_parameters(
|
|
self,
|
|
params_per_rank: Optional[list[list[torch.Tensor]]] = None,
|
|
) -> list[list[dict]]:
|
|
r"""
|
|
Partitions parameters across distributed data parallel ranks.
|
|
|
|
Arguments:
|
|
params_per_rank (list[list[torch.Tensor]], optional): a
|
|
:class:`list` of length world size containing :class:`list` s
|
|
of parameters to assign to each rank; this provides a way to
|
|
specify a partition manually.
|
|
If ``None``, the parameters are partitioned according to an
|
|
internal algorithm.
|
|
(default: ``None``)
|
|
|
|
Returns:
|
|
A :class:`list` where each element of the list contains the
|
|
``param_groups`` for a rank (which itself is a :class:`list` of
|
|
:class:`dict`); element 0 corresponds to rank 0, etc.; each rank
|
|
stores the ``param_groups`` for all ranks for the collective
|
|
communication in :meth:`step`.
|
|
|
|
Raises:
|
|
ValueError: see :meth:`_verify_params_per_rank`.
|
|
RuntimeError: if ``params_per_rank`` is not ``None`` and this
|
|
:class:`ZeroRedundancyOptimizer` instance is using more than
|
|
one parameter group.
|
|
"""
|
|
if params_per_rank is None:
|
|
# Partition the parameters optimizing for uniformity
|
|
if len(self._partition_parameters_cache) == 0:
|
|
self._partition_parameters_cache = [[] for _ in range(self.world_size)]
|
|
sizes = [0] * self.world_size
|
|
for param_group in self.param_groups:
|
|
param_group_params_per_rank: list[list] = [
|
|
[] for _ in range(self.world_size)
|
|
]
|
|
# Sort the parameters by size (largest first)
|
|
params_sorted = sorted(
|
|
param_group["params"], key=lambda t: t.numel(), reverse=True
|
|
)
|
|
for param in params_sorted:
|
|
# Greedily add the parameter to rank with smallest size so far
|
|
rank = self._get_min_index(sizes)
|
|
param_group_params_per_rank[rank].append(param)
|
|
sizes[rank] += param.numel()
|
|
# Apply the constructed partition of the parameter group
|
|
self._partition_param_group(
|
|
param_group, param_group_params_per_rank
|
|
)
|
|
|
|
return self._partition_parameters_cache
|
|
|
|
# Partition the parameters according to `params_per_rank`
|
|
assert len(self._partition_parameters_cache) == 0, (
|
|
"Specifying `params_per_rank` should only be done when the "
|
|
"parameters have not been partitioned yet"
|
|
)
|
|
if len(self.param_groups) != 1:
|
|
raise RuntimeError(
|
|
"Specifying `params_per_rank` only supports a single parameter group"
|
|
)
|
|
self._verify_params_per_rank(params_per_rank)
|
|
self._partition_parameters_cache = [[] for _ in range(self.world_size)]
|
|
|
|
# Apply the passed-in partition of the parameter group
|
|
param_group = self.param_groups[0]
|
|
self._partition_param_group(param_group, params_per_rank)
|
|
|
|
return self._partition_parameters_cache
|
|
|
|
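# Worked example of the sorted-greedy partition above (toy numbers, hedged):
# with a world size of 2 and parameter sizes [300, 200, 150, 100] (already
# sorted descending), rank 0 gets 300, rank 1 gets 200, rank 1 (smaller running
# total) gets 150, and rank 0 gets 100, for per-rank totals of 400 and 350.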
@property
|
|
def _param_to_rank(self) -> dict[torch.Tensor, int]:
|
|
r""":class:`dict` mapping parameters to their assigned data parallel rank in the partition."""
|
|
if len(self._param_to_rank_cache) == 0:
|
|
for rank, param_groups in enumerate(self._partition_parameters()):
|
|
for param_group in param_groups:
|
|
for param in param_group["params"]:
|
|
self._param_to_rank_cache[param] = rank
|
|
return self._param_to_rank_cache
|
|
|
|
@property
|
|
def _param_to_index(self) -> dict[torch.Tensor, int]:
|
|
r"""
|
|
:class:`dict` mapping parameters to their indices in the global optimizer state.
|
|
|
|
NOTE: This assumes that the global optimizer state's indexing (in
|
|
``state_dict``) follows a linear ordering over the parameter groups.
|
|
"""
|
|
if len(self._param_to_index_cache) == 0:
|
|
self._param_to_index_cache = {
|
|
p: i
|
|
for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))
|
|
}
|
|
return self._param_to_index_cache
|
|
|
|
@property
|
|
def _index_to_param(self) -> list[torch.Tensor]:
|
|
r"""List mapping parameter indices in the global optimizer scheme to the actual params."""
|
|
if len(self._index_to_param_cache) == 0:
|
|
self._index_to_param_cache = list(
|
|
chain(*(g["params"] for g in self.param_groups))
|
|
)
|
|
return self._index_to_param_cache
|
|
|
|
def _broadcast_params_from_rank(self, rank: int):
|
|
r"""
|
|
Broadcast the shard of parameters from a given rank to all other ranks asynchronously.
|
|
|
|
Arguments:
|
|
rank (int): the source rank.
|
|
|
|
Returns:
|
|
A :class:`list` of async work handles for the ``broadcast()`` s
|
|
performed to synchronize the parameters.
|
|
"""
|
|
assert not self._overlap_with_ddp, (
|
|
"`_broadcast_params_from_rank()` should not be used if "
|
|
"`overlap_with_ddp=True`; instead, the broadcasting should "
|
|
"happen in the DDP communication hook"
|
|
)
|
|
handles = []
|
|
if self.parameters_as_bucket_view:
|
|
for dev_i_buckets in self._buckets:
|
|
bucket = dev_i_buckets[rank]
|
|
global_rank = dist.distributed_c10d.get_global_rank(
|
|
self.process_group, rank
|
|
)
|
|
handles.append(
|
|
dist.broadcast(
|
|
tensor=bucket,
|
|
src=global_rank,
|
|
group=self.process_group,
|
|
async_op=True,
|
|
)
|
|
)
|
|
else:
|
|
param_groups = self._partition_parameters()[rank]
|
|
global_rank = dist.distributed_c10d.get_global_rank(
|
|
self.process_group, rank
|
|
)
|
|
for param_group in param_groups:
|
|
handles.extend(
|
|
dist.broadcast(
|
|
tensor=param.data,
|
|
src=global_rank,
|
|
group=self.process_group,
|
|
async_op=True,
|
|
)
|
|
for param in param_group["params"]
|
|
)
|
|
return handles
|
|
|
|
def _sync_params(self):
|
|
r"""
|
|
Sync all parameter shards across the ranks.
|
|
|
|
This rank sends its shard of the parameters to all other ranks and
|
|
receives a shard from each other rank. This is done using
|
|
``broadcast()``. Parameters are sent bucket-by-bucket if
|
|
``parameters_as_bucket_view=True`` and sent parameter-by-parameter
|
|
otherwise.
|
|
"""
|
|
handles = []
|
|
for rank in range(self.world_size):
|
|
handles.extend(self._broadcast_params_from_rank(rank))
|
|
_ = [x.wait() for x in handles]
|
|
|
|
@property
|
|
def _device_to_params_per_rank(
|
|
self,
|
|
) -> dict[torch.device, list[list[torch.Tensor]]]:
|
|
r"""
|
|
Return device parameters assigned per rank.
|
|
|
|
:class:`dict` mapping each device to a :class:`list` of the per-rank parameter
|
|
lists filtered to only include the parameters stored on that device.
|
|
Each per-rank parameter list gives the parameters assigned to that rank
|
|
to update.
|
|
|
|
This is used for constructing the parameter buckets if
|
|
``parameters_as_bucket_view=True``.
|
|
|
|
Let ``dev_i`` denote the ``i``th device for this rank. Then:
|
|
``dev_0`` maps to a list containing:
|
|
rank 0's assigned parameters stored on ``dev_0``,
|
|
rank 1's assigned parameters stored on ``dev_0``,
|
|
...
|
|
``dev_1`` maps to a list containing:
|
|
rank 0's assigned parameters stored on ``dev_1``,
|
|
rank 1's assigned parameters stored on ``dev_1``,
|
|
...
|
|
...
|
|
"""
|
|
assert self.parameters_as_bucket_view, (
|
|
"`_device_to_params_per_rank` should only be used if "
|
|
"`parameters_as_bucket_view=True`"
|
|
)
|
|
if len(self._device_to_params_per_rank_cache) == 0:
|
|
for rank, param_groups in enumerate(self._partition_parameters()):
|
|
for param_group in param_groups:
|
|
for param in param_group["params"]:
|
|
device = param.device
|
|
if device not in self._device_to_params_per_rank_cache:
|
|
self._device_to_params_per_rank_cache[device] = [
|
|
[] for _ in range(self.world_size)
|
|
]
|
|
self._device_to_params_per_rank_cache[device][rank].append(
|
|
param
|
|
)
|
|
return self._device_to_params_per_rank_cache
|
|
|
|
def _get_min_index(
|
|
self,
|
|
values: list[int],
|
|
disallowed_indices: Optional[set[int]] = None,
|
|
) -> int:
|
|
r"""
|
|
Return ``values.index(min(values))``, except it only uses one pass.
|
|
|
|
It also excludes any indices in ``disallowed_indices`` if provided.
|
|
|
|
Arguments:
|
|
values (List[int]): :class:`list` of values.
|
|
disallowed_indices (Optional[Set[int]]): indices that are
|
|
disallowed from being the returned min index.
|
|
"""
|
|
min_index = -1
|
|
min_value = float("inf")
|
|
for i, value in enumerate(values):
|
|
if disallowed_indices and i in disallowed_indices:
|
|
continue
|
|
if value < min_value:
|
|
min_value = value
|
|
min_index = i
|
|
assert min_index >= 0, "All indices are disallowed"
|
|
return min_index
|
|
|
|
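# For example, `self._get_min_index([5, 2, 7], disallowed_indices={1})`
# returns 0: index 1 holds the smallest value but is disallowed, so the
# smallest allowed value (5, at index 0) wins.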
def _assign_bucket_subset_to_rank(
|
|
self,
|
|
bucket_index: int,
|
|
bucket_params: list[torch.Tensor],
|
|
bucket_offset: int,
|
|
assigned_rank: int,
|
|
assigned_ranks_per_bucket: list[set[int]],
|
|
) -> None:
|
|
r"""
|
|
Assign ``bucket_params`` to the rank with the least size assigned so far and collect relevant information.

The model parameters given by ``bucket_params`` represent a (possibly non-strict)
subset of the parameters corresponding to a :class:`DistributedDataParallel` bucket.
|
|
|
|
Arguments:
|
|
bucket_index (int): index of the :class:`DistributedDataParallel`
|
|
gradient bucket.
|
|
bucket_params (List[torch.Tensor]): subset of the parameters
|
|
corresponding to the bucket to assign.
|
|
bucket_offset (int): offset giving the index of the first element
|
|
in ``bucket_params`` in the bucket's full parameter list.
|
|
assigned_rank (int): group rank to assign to.
|
|
assigned_ranks_per_bucket (List[Set[int]]): :class:`set` of group ranks
|
|
assigned to each bucket.
|
|
"""
|
|
overlap_info = self._overlap_info
|
|
if len(bucket_params) == 0:
|
|
raise ValueError("Empty bucket assignment")
|
|
params_per_rank = overlap_info.params_per_rank
|
|
offsets = overlap_info.offsets
|
|
|
|
self._bucket_assignments_per_rank_cache[assigned_rank][bucket_index] = (
|
|
_DDPBucketAssignment(bucket_index, bucket_params, bucket_offset)
|
|
)
|
|
if self.global_rank == assigned_rank:
|
|
offsets[bucket_index] = len(params_per_rank[assigned_rank])
|
|
params_per_rank[assigned_rank].extend(bucket_params)
|
|
assigned_ranks_per_bucket[bucket_index].add(assigned_rank)
|
|
self._overlap_info.num_bucket_assignments += 1
|
|
|
|
@property
|
|
def _bucket_assignments_per_rank(self) -> list[dict[int, _DDPBucketAssignment]]:
|
|
r"""
|
|
Return DDP bucket parameters assigned per rank.
|
|
|
|
:class:`list` of length world size consisting of :class:`dict` s
|
|
mapping bucket indices to :class:`_DDPBucketAssignment` s for each
|
|
rank.
|
|
"""
|
|
assert self._overlap_with_ddp, (
"`_bucket_assignments_per_rank` should only be used if `overlap_with_ddp=True`"
)
|
|
if len(self._bucket_assignments_per_rank_cache) > 0:
|
|
return self._bucket_assignments_per_rank_cache
|
|
|
|
overlap_info = self._overlap_info
|
|
assert overlap_info.status == _OverlapStatus.INITIALIZED
|
|
|
|
self._bucket_assignments_per_rank_cache = [{} for _ in range(self.world_size)]
|
|
params_per_bucket = overlap_info.params_per_bucket
|
|
|
|
if overlap_info.shard_buckets:
|
|
# Define the assignment threshold to approximate uniformity
|
|
assert overlap_info.total_size is not None, "`total_size` was not computed"
|
|
threshold = overlap_info.total_size / self.world_size # type: ignore[operator]
|
|
size_per_rank = [0 for _ in range(self.world_size)]
|
|
|
|
num_buckets = len(params_per_bucket)
|
|
overlap_info.assigned_ranks_per_bucket = [set() for _ in range(num_buckets)]
|
|
assigned_ranks_per_bucket = overlap_info.assigned_ranks_per_bucket
|
|
if not overlap_info.shard_buckets:
|
|
# Assign each DDP bucket entirely to a single rank
|
|
for bucket_index, bucket_params in enumerate(params_per_bucket):
|
|
assert len(bucket_params) > 0, "Empty bucket"
|
|
assigned_rank = self._get_assigned_rank(bucket_index)
|
|
self._assign_bucket_subset_to_rank(
|
|
bucket_index,
|
|
bucket_params,
|
|
0,
|
|
assigned_rank,
|
|
assigned_ranks_per_bucket,
|
|
)
|
|
else:
|
|
# Assign each DDP bucket to possibly multiple ranks
|
|
# Specifically, sort the DDP buckets by increasing size, and for
|
|
# each bucket, iteratively assign the maximal unassigned subset
|
|
# with size less than `threshold` to the rank with the least total
|
|
# size so far -- each such assignment is represented by a
|
|
# `_DDPBucketAssignment` instance and only contains parameters from
|
|
# a single DDP bucket
|
|
params_per_bucket_enum = sorted(
|
|
enumerate(params_per_bucket), key=lambda x: sum(p.numel() for p in x[1])
|
|
)
|
|
for bucket_index, bucket_params in params_per_bucket_enum:
|
|
assert len(bucket_params) > 0, "Empty bucket"
|
|
bucket_offset = 0
|
|
assignment_size = 0
|
|
for param_index, param in enumerate(bucket_params):
|
|
param_numel = param.numel()
|
|
if (
|
|
assignment_size + param_numel >= threshold
|
|
and param_index > bucket_offset
|
|
):
|
|
assigned_rank = self._get_min_index(
|
|
size_per_rank, assigned_ranks_per_bucket[bucket_index]
|
|
)
|
|
# Include up to but not including the parameter that
|
|
# exceeded the threshold
|
|
self._assign_bucket_subset_to_rank(
|
|
bucket_index,
|
|
bucket_params[bucket_offset:param_index],
|
|
bucket_offset,
|
|
assigned_rank,
|
|
assigned_ranks_per_bucket,
|
|
)
|
|
size_per_rank[assigned_rank] += assignment_size
|
|
bucket_offset = param_index
|
|
assignment_size = 0
|
|
assignment_size += param_numel
|
|
# Assign the remainder of the bucket so that no assignment
|
|
# spans across two buckets
|
|
assigned_rank = self._get_min_index(
|
|
size_per_rank, assigned_ranks_per_bucket[bucket_index]
|
|
)
|
|
self._assign_bucket_subset_to_rank(
|
|
bucket_index,
|
|
bucket_params[bucket_offset:],
|
|
bucket_offset,
|
|
assigned_rank,
|
|
assigned_ranks_per_bucket,
|
|
)
|
|
size_per_rank[assigned_rank] += assignment_size
|
|
|
|
return self._bucket_assignments_per_rank_cache
|
|
|
|
def _local_step(
|
|
self,
|
|
gradients: Optional[list[Optional[torch.Tensor]]] = None,
|
|
closure: Optional[Callable[[], float]] = None,
|
|
**kwargs: Any,
|
|
) -> Optional[float]:
|
|
r"""
|
|
Perform a single optimizer step without syncing parameters across ranks.
|
|
|
|
Arguments:
|
|
gradients (list[Optional[torch.Tensor]], optional): a :class:`list`
|
|
of length equal to the number of parameters assigned to this
|
|
rank containing gradient tensors or ``None`` as its elements;
|
|
a ``None`` in the :class:`list` indicates that the
|
|
corresponding parameter should not be updated.
|
|
If the argument itself is ``None``, then all parameters are
|
|
updated, and the gradients are assumed to be already populated.
|
|
(default: ``None``)
|
|
closure (Callable): a closure that re-evaluates the model and
|
|
returns the loss; optional for most optimizers and should be
|
|
``None`` if ``gradients`` is not ``None``; (default: ``None``)
|
|
Returns:
|
|
Optional loss depending on the underlying local optimizer.
|
|
|
|
.. warning::
|
|
The argument ``gradients`` should only be specified (i.e. not
|
|
``None``) if ``overlap_with_ddp=True``, in which case
|
|
:class:`ZeroRedundancyOptimizer` wraps a functional optimizer.
|
|
"""
|
|
Join.notify_join_context(self)
|
|
# Check if the model trainability has changed
|
|
is_trainable_mask = self._get_is_trainable_mask()
|
|
if is_trainable_mask != self._is_trainable_mask:
|
|
if self._overlap_with_ddp:
|
|
raise RuntimeError(
|
|
"ZeroRedundancyOptimizer with `overlap_with_ddp=True` "
|
|
"does not support changing parameter trainability at run "
|
|
"time"
|
|
)
|
|
logger.warning(
|
|
"ZeroRedundancyOptimizer detected that the trainable "
|
|
"parameters changed; rebuilding the parameter buckets if "
|
|
"enabled"
|
|
)
|
|
self._build_param_buckets()
|
|
self._is_trainable_mask = is_trainable_mask
|
|
|
|
# Sync the exposed `param_groups` attributes to the local optimizer in
|
|
# case they have been updated
|
|
self._sync_param_groups(self.param_groups, self.optim.param_groups)
|
|
|
|
# Run the optimizer step on this shard only
|
|
if gradients is None:
|
|
loss = (
|
|
self.optim.step(**kwargs)
|
|
if closure is None
|
|
else self.optim.step(closure=closure, **kwargs)
|
|
)
|
|
else:
|
|
assert self._overlap_with_ddp, (
|
|
"Specifying `gradients` should not "
|
|
"be used when `overlap_with_ddp=False`"
|
|
)
|
|
assert closure is None, (
|
|
"`closure` is not supported when using a local functional optimizer"
|
|
)
|
|
loss = self.optim.step(gradients=gradients)
|
|
|
|
# Sync any updated attributes in the local optimizer to the exposed
|
|
# `param_groups`
|
|
self._sync_param_groups(self.optim.param_groups, self.param_groups)
|
|
|
|
return loss
|
|
|
|
def step(
|
|
self,
|
|
closure: Optional[Callable[[], float]] = None,
|
|
**kwargs: Any,
|
|
) -> Optional[float]:
|
|
r"""
|
|
Perform a single optimizer step and sync parameters across all ranks.
|
|
|
|
Arguments:
|
|
closure (Callable): a closure that re-evaluates the model and
|
|
returns the loss; optional for most optimizers.
|
|
Returns:
|
|
Optional loss depending on the underlying local optimizer.
|
|
|
|
.. note:: Any extra parameters are passed to the base optimizer as-is.
|
|
"""
|
|
if self._overlap_with_ddp:
|
|
logger.warning(
|
|
"`step()` should not be included in the training loop when "
|
|
"`overlap_with_ddp=True`"
|
|
)
|
|
return None
|
|
|
|
# Perform the local optimizer step
|
|
loss = self._local_step(closure=closure, **kwargs)
|
|
|
|
# Sync all of the updated parameter shards across the ranks
|
|
self._sync_params()
|
|
|
|
return loss
|
|
|
|
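# Illustrative training-loop sketch (hedged; `model`, `loader`, and `loss_fn`
# are hypothetical) showing where `step()` fits when `overlap_with_ddp=False`:
#
#     for inputs, targets in loader:
#         optimizer.zero_grad()
#         loss_fn(model(inputs), targets).backward()
#         optimizer.step()  # local shard update, then parameter broadcast
#
# With `overlap_with_ddp=True`, `step()` only logs a warning; the update runs
# inside the registered DDP communication hook instead.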
def join_hook(self, **kwargs):
|
|
r"""
|
|
Return the ZeRO join hook.
|
|
|
|
It enables training on uneven inputs by
|
|
shadowing the collective communications in the optimizer step.
|
|
|
|
Gradients must be properly set before this hook is called.
|
|
|
|
Arguments:
|
|
kwargs (dict): a :class:`dict` containing any keyword arguments
|
|
to modify the behavior of the join hook at run time; all
|
|
:class:`Joinable` instances sharing the same join context
|
|
manager are forwarded the same value for ``kwargs``.
|
|
|
|
This hook does not support any keyword arguments; i.e. ``kwargs`` is
|
|
unused.
|
|
"""
|
|
return _ZeROJoinHook(self)
|
|
|
|
@property
|
|
def join_device(self) -> torch.device:
|
|
r"""Return default device."""
|
|
return self._default_device
|
|
|
|
@property
|
|
def join_process_group(self) -> Any:
|
|
r"""Return process group."""
|
|
return self.process_group
|
|
|
|
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
|
|
r"""
|
|
Load the state pertaining to the given rank from the input ``state_dict``, updating the local optimizer as needed.
|
|
|
|
Arguments:
|
|
state_dict (dict): optimizer state; should be an object returned
|
|
from a call to :meth:`state_dict`.
|
|
|
|
Raises:
|
|
RuntimeError: if ``overlap_with_ddp=True`` and this method is
|
|
called before this :class:`ZeroRedundancyOptimizer` instance
|
|
has been fully initialized, which happens once
|
|
:class:`DistributedDataParallel` gradient buckets have been
|
|
rebuilt.
|
|
"""
|
|
self._check_overlap_initialized()
|
|
|
|
for index, value in state_dict["state"].items():
|
|
param = self._index_to_param[index]
|
|
if self._param_to_rank[param] != self.rank:
|
|
# Clear any state irrelevant to this rank
|
|
state_dict["state"][index] = None
|
|
else:
|
|
# Load the parameter state to the local optimizer
|
|
self.optim.state[param] = _recursive_copy_to_device(
|
|
value, non_blocking=True, device=param.device
|
|
)
|
|
# Force zero-dimensional tensors (like Adam "step") on CPU
|
|
for state_name, state_value in self.optim.state[param].items():
|
|
if torch.is_tensor(state_value) and state_value.dim() == 0:
|
|
self.optim.state[param][state_name] = state_value.cpu()
|
|
|
|
super().load_state_dict(state_dict)
|
|
|
|
# Sync the input state with the exposed and local optimizer states
|
|
self._sync_param_groups(state_dict["param_groups"], self.param_groups)
|
|
self._sync_param_groups(self.param_groups, self.optim.param_groups)
|
|
|
|
def state_dict(self) -> dict[str, Any]:
|
|
r"""
|
|
Return the last global optimizer state known to this rank.
|
|
|
|
.. warning::
|
|
If the state has not been consolidated to this rank, this raises a
|
|
runtime error, and even if it has, the state may not be up-to-date,
|
|
depending on when :meth:`consolidate_state_dict` was last called.
|
|
|
|
Raises:
|
|
RuntimeError: if ``overlap_with_ddp=True`` and this method is
|
|
called before this :class:`ZeroRedundancyOptimizer` instance
|
|
has been fully initialized, which happens once
|
|
:class:`DistributedDataParallel` gradient buckets have been
|
|
rebuilt; or if this method is called without a preceding call
|
|
to :meth:`consolidate_state_dict`.
|
|
"""
|
|
self._check_overlap_initialized()
|
|
|
|
if len(self._all_state_dicts) == 0:
|
|
raise RuntimeError(
|
|
"Optimizer state has not been consolidated on this rank. "
|
|
f"Please call `consolidate_state_dict(to={self.rank})` on "
|
|
"all ranks beforehand if you meant to save the global state."
|
|
)
|
|
|
|
# Get the possibly-stale global optimizer state that uses global
|
|
# parameter indexing
|
|
state_dict = super().state_dict()
|
|
|
|
# Update the global optimizer state with local state information,
|
|
# factoring in the translation from local to global indexing
|
|
for rank, local_state_dict in enumerate(self._all_state_dicts):
|
|
local_param_groups = local_state_dict["param_groups"]
|
|
global_param_groups = self._partition_parameters()[rank]
|
|
assert len(local_param_groups) == len(global_param_groups), (
|
|
"Mismatch between number of local and global parameter groups"
|
|
)
|
|
|
|
for local_param_group, global_param_group in zip(
|
|
local_param_groups, global_param_groups
|
|
):
|
|
# `local_param_group` stores local indices, while
|
|
# `global_param_group` stores the tensors directly
|
|
local_param_indices = local_param_group["params"]
|
|
global_params = global_param_group["params"]
|
|
|
|
assert len(local_param_indices) == len(global_params), (
|
|
"Mismatch between number of local and global parameters in parameter group"
|
|
)
|
|
for local_param_index, global_param in zip(
|
|
local_param_indices, global_params
|
|
):
|
|
# Update the global parameter state, if any
|
|
if local_param_index in local_state_dict["state"]:
|
|
global_param_index = self._param_to_index[global_param]
|
|
state_dict["state"][global_param_index] = local_state_dict[
|
|
"state"
|
|
][local_param_index]
|
|
|
|
# Sort the parameters in the state
|
|
state_dict["state"] = dict(sorted(state_dict["state"].items()))
|
|
return state_dict
|
|
|
|
@staticmethod
|
|
def _sync_param_groups(
|
|
src_param_groups: list[dict[Any, Any]],
|
|
dst_param_groups: list[dict[Any, Any]],
|
|
) -> None:
|
|
r"""
|
|
Sync the attributes from the source parameter groups to the destination parameter groups.
|
|
|
|
Example attributes include learning rate or scheduler attributes. The
|
|
two parameter groups should have the same length (i.e. same number of
|
|
parameter groups).
|
|
|
|
Arguments:
|
|
src_param_groups (list[dict]): parameter groups giving the
|
|
attribute settings to copy.
|
|
dst_param_groups (list[dict]): parameter groups giving the
|
|
attribute settings to set.
|
|
"""
|
|
assert len(src_param_groups) == len(dst_param_groups), (
|
|
"Mismatch between number of source and destination parameter groups"
|
|
)
|
|
for src_param_group, dst_param_group in zip(src_param_groups, dst_param_groups):
|
|
# Sync all attributes except the parameters
|
|
for attr in filter(lambda x: x != "params", src_param_group.keys()):
|
|
dst_param_group[attr] = src_param_group[attr]
|
|
|
|
def _build_param_buckets(self) -> None:
|
|
r"""
|
|
Build parameter buckets if ``parameters_as_bucket_view=True``.
|
|
|
|
For each device that stores this rank's parameters, there is a
|
|
bucket (represented as a tensor) containing all of the parameters on
|
|
that device that are assigned to a given rank in the parameter update
|
|
partition.
|
|
|
|
This method is called in the constructor and any time parameter
|
|
trainability is changed.
|
|
|
|
.. warning::
|
|
The current implementation assumes that all of the parameters in a
|
|
bucket are of the same dense type when allocating the bucket's
|
|
tensor.
|
|
|
|
.. warning::
|
|
If the model parameters are stored across more than one device,
|
|
then the storage partitioning must be the same across all
|
|
processes in order for parameter synchronization to work.
|
|
"""
|
|
if not self.parameters_as_bucket_view or self._overlap_with_ddp:
|
|
return
|
|
|
|
# `self._buckets[i][j]` are the parameters stored on device i and
|
|
# assigned to rank j
|
|
num_devices = len(self._device_to_params_per_rank)
|
|
self._buckets = [[] for _ in range(num_devices)] # type: ignore[assignment]
|
|
|
|
for dev_i, (device, params_per_rank) in enumerate(
|
|
self._device_to_params_per_rank.items()
|
|
):
|
|
for params in params_per_rank:
|
|
bucket_size = 0
|
|
dtype = None
|
|
trainable_params = []
|
|
for param in params:
|
|
if not _is_trainable(param):
|
|
# Clone in case the parameter was previously part of
|
|
# a bucket to avoid the data from being destroyed
|
|
param.data = param.data.detach().clone()
|
|
else:
|
|
bucket_size += param.numel()
|
|
trainable_params.append(param)
|
|
dtype = param.dtype # assumes all same dtype
|
|
|
|
if bucket_size == 0:
|
|
# Create a dummy bucket if there are no parameters
|
|
bucket = torch.zeros(1, device=device)
|
|
else:
|
|
# Construct the bucket (assuming all dense and same dtype)
|
|
bucket = torch.empty(bucket_size, dtype=dtype, device=device)
|
|
offset = 0
|
|
for param in trainable_params:
|
|
offset_next = offset + param.numel()
|
|
bucket[offset:offset_next].copy_(param.data.flatten())
|
|
param.data = bucket[offset:offset_next].view_as(param.data)
|
|
offset = offset_next
|
|
self._buckets[dev_i].append(bucket) # type: ignore[arg-type]
|
|
|
|
def _build_ddp_param_buckets(self) -> None:
|
|
r"""
|
|
Build the DDP bucket with parameters assigned to this rank.
|
|
|
|
For each DDP bucket with parameters assigned to this rank, flattens the
|
|
data of those parameters into a single tensor and saves the tensor to
|
|
the ``tensor`` attribute in the corresponding
|
|
:class:`_DDPBucketAssignment` instance stored in
|
|
``self._bucket_assignments_per_rank``.
|
|
|
|
:class:`DistributedDataParallel` guarantees that the parameters
|
|
corresponding to a gradient bucket have the same device and the same
|
|
dtype.
|
|
"""
|
|
for bucket_assignments in self._bucket_assignments_per_rank:
|
|
for bucket_assignment in bucket_assignments.values():
|
|
params = bucket_assignment.parameters
|
|
bucket_size = 0
|
|
dtype = None
|
|
for param in params:
|
|
assert _is_trainable(param), (
|
|
"Model parameter "
|
|
"corresponding to a gradient in a DDP bucket should "
|
|
"require a gradient"
|
|
)
|
|
bucket_size += param.numel()
|
|
dtype = param.dtype # assumes all same dtype
|
|
assert bucket_size > 0, "Empty bucket"
|
|
|
|
# Construct the bucket tensor (assuming all dense and same dtype)
|
|
tensor = torch.empty(
|
|
bucket_size, dtype=dtype, device=bucket_assignment.device
|
|
)
|
|
offset = 0
|
|
for param in params:
|
|
offset_next = offset + param.numel()
|
|
tensor[offset:offset_next].copy_(param.data.flatten())
|
|
param.data = tensor[offset:offset_next].view_as(param.data)
|
|
offset = offset_next
|
|
bucket_assignment.tensor = tensor
|
|
|
|
def _verify_and_init_params(
|
|
self,
|
|
params: Any,
|
|
) -> Union[list[torch.Tensor], list[dict]]:
|
|
r"""
|
|
Verify the type of ``params`` and initialize ``self._all_params`` as a :class:`list` of all parameters.

The initialization first makes sure that the provided ``params`` is valid.
|
|
|
|
Arguments:
|
|
params (Any): Candidate parameter list or parameter groups to verify.
|
|
|
|
Raises:
|
|
TypeError: ``params`` has an invalid type.
|
|
ValueError: ``params`` is empty.
|
|
|
|
Returns:
|
|
The persistent form of ``params`` to be passed into the parent
|
|
:class:`Optimizer` constructor -- i.e. returns ``params`` as a
|
|
:class:`list` to ensure that it can be iterated over again.
|
|
"""
|
|
if isinstance(params, torch.Tensor):
|
|
raise TypeError(
|
|
"`params` argument should be an iterable of "
|
|
f"Tensors, but got {torch.typename(params)}"
|
|
)
|
|
try:
|
|
all_params = list(params)
|
|
except TypeError as e:
|
|
raise TypeError(
|
|
"`params` argument should be an iterable of Tensors"
|
|
f" or dicts, but got {torch.typename(params)}"
|
|
) from e
|
|
if len(all_params) == 0:
|
|
raise ValueError("ZeroRedundancyOptimizer got an empty parameter list")
|
|
all_tensors = True
|
|
all_dicts = True
|
|
for param in all_params:
|
|
all_tensors &= isinstance(param, torch.Tensor)
|
|
all_dicts &= isinstance(param, dict)
|
|
if not all_tensors and not all_dicts:
|
|
raise TypeError(
|
|
"`params` argument should be an iterable of Tensors or dicts"
|
|
)
|
|
# Ensure that `self._all_params` contains a list of all parameters
|
|
if all_tensors:
|
|
self._all_params = all_params
|
|
elif all_dicts:
|
|
self._all_params = []
|
|
# `all_params` contains parameter groups (not parameters)
|
|
for param_group in all_params:
|
|
if "params" not in param_group:
|
|
raise ValueError(
|
|
"Each parameter group passed-in via `params` must "
|
|
"have a 'params' key mapping to the parameters in "
|
|
"the group"
|
|
)
|
|
self._all_params.extend(param_group["params"])
|
|
return all_params
|
|
|
|
def _verify_same_dense_param_type(self) -> None:
|
|
r"""
|
|
Verify that all parameters are of the same dense type.
|
|
|
|
The method assumes that ``self._all_params`` has been initialized
|
|
and is non-empty.
|
|
|
|
Raises:
|
|
ValueError: ``params`` contains sparse parameters or parameters
|
|
of varying dense types.
|
|
|
|
NOTE: This method can be removed once support for sparse parameters
|
|
and varying parameter types is added.
|
|
"""
|
|
typename = torch.typename(self._all_params[0])
|
|
if self._all_params[0].is_sparse:
|
|
raise ValueError(
|
|
"ZeroRedundancyOptimizer only supports using "
|
|
"the same dense type for all parameters but got "
|
|
f"{typename}"
|
|
)
|
|
for param in self._all_params[1:]:
|
|
other_typename = torch.typename(param)
|
|
if other_typename != typename:
|
|
raise ValueError(
|
|
"ZeroRedundancyOptimizer only supports "
|
|
"using the same dense type for all "
|
|
f"parameters but got both {typename} and "
|
|
f"{other_typename}"
|
|
)
|
|
|
|
def _get_is_trainable_mask(self) -> list[bool]:
|
|
r"""Return a boolean mask indicating if each parameter is trainable (``requires_grad``) or not."""
|
|
return list(map(_is_trainable, self._all_params))
|
|
|
|
def _init_local_optimizer(self) -> None:
|
|
r"""
|
|
Initialize this rank's local optimizer, responsible for its subset of the parameters.
|
|
|
|
The local optimizer is saved in ``self.optim``.
|
|
"""
|
|
assert self._optim_constructor is not None, (
|
|
"The local optimizer class has not been set"
|
|
)
|
|
|
|
param_groups = self._partition_parameters()[self.rank]
|
|
# `overlap_with_ddp=True` requires a local functional optimizer
|
|
if self._overlap_with_ddp:
|
|
# Functional optimizers only support a single parameter group and
|
|
# require passing in the parameters as a list
|
|
assert len(param_groups) == 1, (
|
|
"Initializing the local "
|
|
"functional optimizer with more than one parameter group"
|
|
)
|
|
params = param_groups[0]["params"]
|
|
# Try to pass `_allow_empty_param_list=True` to avoid erroring
|
|
if (
|
|
"_allow_empty_param_list"
|
|
in inspect.signature(self._optim_constructor).parameters
|
|
):
|
|
self.optim: Any = self._optim_constructor(
|
|
params, **self._optim_defaults, _allow_empty_param_list=True
|
|
)
|
|
else:
|
|
logger.warning(
|
|
"%s does not support the argument "
|
|
"`_allow_empty_param_list`; ZeroRedundancyOptimizer may "
|
|
"error due to an empty parameter list",
|
|
self._optim_constructor,
|
|
)
|
|
self.optim: Any = self._optim_constructor(
|
|
params, **self._optim_defaults
|
|
) # type: ignore[no-redef]
|
|
|
|
# Log information about the DDP and ZeRO bucketing
|
|
if dist.get_debug_level() != dist.DebugLevel.OFF:
|
|
local_numel = sum(p.numel() for p in params)
|
|
num_assigned_buckets = len(
|
|
self._bucket_assignments_per_rank[self.global_rank]
|
|
)
|
|
logger.info(
|
|
"rank %s with %s parameters across %s buckets",
|
|
self.global_rank,
|
|
local_numel,
|
|
num_assigned_buckets,
|
|
)
|
|
if self.global_rank == 0:
|
|
logger.info(
|
|
"%s DDP buckets and %s bucket assignments",
|
|
len(self._overlap_info.params_per_bucket),
|
|
self._overlap_info.num_bucket_assignments,
|
|
)
|
|
else:
|
|
# NOTE: Passing `param_groups` into the local optimizer constructor
|
|
# bypasses the empty parameter list check
|
|
self.optim: Optimizer = self._optim_constructor(
|
|
param_groups, **self._optim_defaults
|
|
) # type: ignore[no-redef]
|
|
|
|
# TODO: Manually add `self.param_groups` if using a functional
|
|
# optimizer; remove this if/when the functional optimizers support
|
|
# multiple parameter groups
|
|
if self._overlap_with_ddp and not hasattr(self.optim, "param_groups"):
|
|
assert hasattr(self.optim, "param_group"), (
|
|
"The functional optimizer should set at least one of the "
|
|
"attributes `param_group` or `param_groups`"
|
|
)
|
|
self.optim.param_groups = [self.optim.param_group] # type: ignore[attr-defined]
|
|
|
|
self._sync_param_groups(self.optim.param_groups, self.param_groups)
|
|
|
|
def _init_zero_for_overlap(self) -> None:
|
|
r"""Perform a delayed initialization of the local optimizer and the supporting data structures."""
|
|
assert self._overlap_with_ddp, (
|
|
"`_init_zero_for_overlap()` should only be called when "
|
|
"`overlap_with_ddp=True`"
|
|
)
|
|
self._overlap_info.status = _OverlapStatus.INITIALIZED
|
|
self._clear_cache()
|
|
self._partition_parameters(self._overlap_info.params_per_rank)
|
|
self._build_ddp_param_buckets()
|
|
self._init_local_optimizer()
|
|
|
|
def _get_assigned_rank(self, bucket_index: int) -> int:
|
|
r"""
|
|
Return the single rank assigned to a :class:`DistributedDataParallel` gradient bucket.
|
|
|
|
Arguments:
|
|
bucket_index (int): index of the :class:`DistributedDataParallel`
|
|
bucket for which to get the assigned rank.
|
|
"""
|
|
assert not self._overlap_info.shard_buckets, (
|
|
"The bucket assignment requires global bucket information and "
|
|
"will be computed later; there should be no need to use this "
|
|
"method"
|
|
)
|
|
return bucket_index % self.world_size
|
|
|
|
def _check_overlap_initialized(self):
|
|
r"""
|
|
Check the delayed initialization depending on the value of ``overlap_with_ddp``.
|
|
|
|
The delayed initialization has occurred (see
|
|
:meth:`_init_zero_for_overlap`) if ``overlap_with_ddp=True``, and
|
|
raises a ``RuntimeError`` if not. This should preface methods that
|
|
should not be run before that delayed initialization.
|
|
|
|
Raises:
|
|
RuntimeError: if ``overlap_with_ddp=True`` and
|
|
:meth:`_init_zero_for_overlap` has not been called.
|
|
"""
|
|
if (
|
|
self._overlap_with_ddp
|
|
and self._overlap_info.status != _OverlapStatus.INITIALIZED
|
|
):
|
|
raise RuntimeError(
|
|
"This method should not be called until this "
|
|
"ZeroRedundancyOptimizer instance has been fully "
|
|
"initialized"
|
|
)
|
|
|
|
def _get_optimizer_constructor(self, optimizer_class: Any) -> Any:
|
|
r"""
|
|
Return the optimizer constructor using validation and transformation depending on ``overlap_with_ddp``.
|
|
|
|
Returns:
|
|
- ``optimizer_class`` if ``overlap_with_ddp=False`` and
|
|
``optimizer_class`` is not a functional optimizer.
|
|
- ``optimizer_class`` if ``overlap_with_ddp=True`` and
|
|
``optimizer_class`` is already a functional optimizer.
|
|
- The functional equivalent of ``optimizer_class`` if
|
|
``overlap_with_ddp=True`` and ``optimizer_class`` is not
|
|
already a functional optimizer (assuming the equivalent
|
|
exists).
|
|
|
|
Raises:
|
|
ValueError:
|
|
|
|
- if ``overlap_with_ddp=True`` but ``optimizer_class`` is
|
|
neither a functional optimizer nor translatable to a
|
|
functional optimizer.
|
|
- if ``overlap_with_ddp=False`` and ``optimizer_class`` is a
|
|
functional optimizer.
|
|
"""
|
|
functional_optims = functional_optim_map.values()
|
|
if not self._overlap_with_ddp:
|
|
if optimizer_class in functional_optims:
|
|
# Using a functional optimizer is only supported when
|
|
# `overlap_with_ddp=True`
|
|
raise ValueError(
|
|
f"Passing in a functional optimizer {optimizer_class} "
|
|
"when `overlap_with_ddp=False`"
|
|
)
|
|
else:
|
|
return optimizer_class
|
|
else:
|
|
if optimizer_class in functional_optims:
|
|
# Already a functional optimizer
|
|
return optimizer_class
|
|
elif optimizer_class in functional_optim_map:
|
|
# Translate the passed-in optimizer class to its functional
|
|
# equivalent if `overlap_with_ddp=True`
|
|
optim_constructor = functional_optim_map[optimizer_class]
|
|
logger.info(
|
|
"Using the functional optimizer %s "
|
|
"instead of %s since "
|
|
"`overlap_with_ddp=True`",
|
|
optim_constructor,
|
|
optimizer_class,
|
|
)
|
|
return optim_constructor
|
|
else:
|
|
raise ValueError(
"Using `overlap_with_ddp=True` requires using a "
"functional optimizer, but there is no supported functional "
|
|
f"optimizer equivalent for {optimizer_class}"
|
|
)
|
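# Illustrative wiring sketch for `overlap_with_ddp=True` (hedged; consult
# `ddp_zero_hook.py` for the supported hook constructors, and treat the exact
# call below as an assumption rather than a guaranteed API):
#
#     from torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook import (
#         hook_with_zero_step,
#     )
#     from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import (
#         allreduce_hook,
#     )
#
#     zero = ZeroRedundancyOptimizer(
#         ddp_model.parameters(),
#         optimizer_class=torch.optim.Adam,
#         overlap_with_ddp=True,
#         lr=0.01,
#     )
#     ddp_model.register_comm_hook(
#         None, hook_with_zero_step(allreduce_hook, ddp_model, zero)
#     )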