# DDP Communication Hooks
DDP communication hook is a generic interface to control how gradients are communicated across workers by overriding the vanilla allreduce in DistributedDataParallel. A few built-in communication hooks are provided, and users can easily apply any of them to optimize communication. In addition, the hook interface supports user-defined communication strategies for more advanced use cases.
## How to Use a Communication Hook?
To use a communication hook, a user just needs to register the hook on the DDP model before the training loop starts, as shown below.

{func}`torch.nn.parallel.DistributedDataParallel.register_comm_hook`
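For example, here is a minimal sketch of registering the built-in FP16 compression hook. It assumes the process group has already been initialized and that `rank` is this process's local GPU index; the small model is only a placeholder.

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

# Assumes torch.distributed.init_process_group(...) has already been called
# and that `rank` is this process's local GPU index.
model = nn.Linear(16, 16).to(rank)
ddp_model = DistributedDataParallel(model, device_ids=[rank])

# Compress gradients to FP16 before the allreduce; this built-in hook is
# stateless, so `state` can simply be None.
ddp_model.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)

# ... regular training loop follows ...
```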
## What Does a Communication Hook Operate On?
A communication hook provides a flexible way to allreduce gradients.
Therefore, it mainly operates on the gradients on each replica before allreduce,
which are bucketized to increase the overlap between communication and computation.
In particular, {class}`torch.distributed.GradBucket` represents a bucket of gradient tensors to be allreduced.
.. autoclass:: torch.distributed.GradBucket
.. autofunction:: torch.distributed.GradBucket.index
.. autofunction:: torch.distributed.GradBucket.buffer
.. autofunction:: torch.distributed.GradBucket.gradients
.. autofunction:: torch.distributed.GradBucket.is_last
.. autofunction:: torch.distributed.GradBucket.set_buffer
.. autofunction:: torch.distributed.GradBucket.parameters
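To make the contract concrete, here is a sketch of a user-defined hook that reproduces vanilla allreduce on a bucket's flattened buffer. The function name is illustrative; the body follows the same pattern as the built-in allreduce hook: a hook takes `(state, bucket)` and returns a `torch.futures.Future` wrapping a single tensor with the same shape as `bucket.buffer()`.

```python
import torch
import torch.distributed as dist


def illustrative_allreduce_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = group_to_use.size()

    # Average the flattened gradients in this bucket across all workers.
    tensor = bucket.buffer().div_(world_size)
    fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future()

    # The hook must return a Future wrapping one tensor shaped like bucket.buffer().
    return fut.then(lambda f: f.value()[0])
```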
## Default Communication Hooks
Default communication hooks are simple stateless hooks, so the input `state`
in `register_comm_hook` is either a process group or `None`.
The input `bucket` is a {class}`torch.distributed.GradBucket` object.
.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.default_hooks
.. autofunction:: allreduce_hook
.. autofunction:: fp16_compress_hook
.. autofunction:: bf16_compress_hook
Additionally, communication hook wrappers are provided to apply {meth}`~fp16_compress_hook` or {meth}`~bf16_compress_hook` as a wrapper, so that FP16/BF16 compression can be combined with other communication hooks.
.. autofunction:: fp16_compress_wrapper
.. autofunction:: bf16_compress_wrapper
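As a sketch (reusing the `ddp_model` from the earlier registration example), a wrapper composes FP16 compression with another hook, for example the PowerSGD hook described in the next section:

```python
from torch.distributed.algorithms.ddp_comm_hooks import (
    default_hooks,
    powerSGD_hook as powerSGD,
)

state = powerSGD.PowerSGDState(process_group=None, matrix_approximation_rank=1)

# Gradients are cast to FP16 before the wrapped PowerSGD hook runs,
# and the result is cast back to the input dtype afterwards.
ddp_model.register_comm_hook(
    state, default_hooks.fp16_compress_wrapper(powerSGD.powerSGD_hook)
)
```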
## PowerSGD Communication Hook
PowerSGD (Vogels et al., NeurIPS 2019) is a gradient compression algorithm that can provide very high compression rates and accelerate bandwidth-bound distributed training. The algorithm needs to maintain both hyperparameters and internal state; therefore, the PowerSGD communication hook is a stateful hook, and the user needs to provide a state object as defined below.
### PowerSGD State
.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook
.. autoclass:: PowerSGDState
### PowerSGD Hooks
PowerSGD typically requires extra memory of the same size as the model's
gradients to enable error feedback, which can compensate for biased
compressed communication and improve accuracy.
PowerSGD hooks may conflict with the [Apex automatic mixed precision package](https://github.com/NVIDIA/apex).
Please use the PyTorch [native automatic mixed precision package](https://pytorch.org/docs/stable/amp.html)
instead.
.. autofunction:: powerSGD_hook
.. autofunction:: batched_powerSGD_hook
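A typical registration, with the key hyperparameters spelled out, might look like the sketch below; the values are illustrative, and `ddp_model` is assumed from the earlier example.

```python
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

state = powerSGD.PowerSGDState(
    process_group=None,           # use the default (global) process group
    matrix_approximation_rank=2,  # higher rank: better approximation, more communication
    start_powerSGD_iter=1000,     # run vanilla allreduce for the first 1,000 iterations
)
ddp_model.register_comm_hook(state, powerSGD.powerSGD_hook)
```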
## Debugging Communication Hooks
As the name implies, debugging communication hooks are only used for debugging and performance optimization purposes.
.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks
Debugging communication hooks do not necessarily output the correct results.
.. autofunction:: noop_hook
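For instance, registering the no-op hook disables gradient communication entirely, which can help estimate the communication overhead of allreduce. The resulting gradients are not synchronized across ranks, and `ddp_model` is assumed from the earlier example.

```python
from torch.distributed.algorithms.ddp_comm_hooks import debugging_hooks

# Returns the bucket buffer unchanged: no gradient synchronization happens.
ddp_model.register_comm_hook(None, debugging_hooks.noop_hook)
```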
## Checkpointing of Communication Hooks
.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook
A stateful communication hook can be saved as part of model checkpointing to enable trainer restarts.
To make a hook serializable, `__setstate__` and `__getstate__` should be defined.
`__getstate__` should exclude non-serializable attributes from the returned dictionary.
`__setstate__` should properly initialize the non-serializable attributes excluded from the provided `state`.
{class}`PowerSGDState` has `__setstate__` and `__getstate__` implemented and can be used as a reference.
.. class:: PowerSGDState
:noindex:
.. automethod:: PowerSGDState.__getstate__
.. automethod:: PowerSGDState.__setstate__
Here is a simple, end-to-end example of saving and reloading PowerSGD state and hook.
```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD
from torch.nn.parallel import DistributedDataParallel


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(24, 24)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(24, 12)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # Initialize the process group.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True)


def demo_serialization(rank, world_size):
    setup(rank, world_size)

    CHECKPOINT = os.path.join(tempfile.gettempdir(), "checkpoint.pt")

    model = SimpleModel().to(rank)
    ddp_model = DistributedDataParallel(model, device_ids=[rank])

    powersgd_hook = powerSGD.powerSGD_hook
    powersgd_state = powerSGD.PowerSGDState(process_group=None)

    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

    # Save the model together with the hook and its state.
    state = {
        "state_dict": ddp_model.state_dict(),
        "comm_hook": powersgd_hook,
        "comm_hook_state": powersgd_state,
    }

    if rank == 0:
        torch.save(state, CHECKPOINT)

    # Make sure rank 0 has finished saving before the other ranks load.
    dist.barrier()
    map_location = {"cuda:%d" % 0: "cuda:%d" % rank}
    # The checkpoint contains a hook callable and a PowerSGDState object,
    # not just tensors, so it cannot be loaded with weights_only=True.
    checkpoint = torch.load(CHECKPOINT, map_location=map_location, weights_only=False)

    new_ddp_model = DistributedDataParallel(SimpleModel().to(rank), device_ids=[rank])
    new_ddp_model.load_state_dict(checkpoint["state_dict"])
    powersgd_hook = checkpoint["comm_hook"]
    powersgd_state = checkpoint["comm_hook_state"]
    new_ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

    # Make sure all ranks have finished loading before the checkpoint is removed.
    dist.barrier()
    if rank == 0:
        os.remove(CHECKPOINT)
    cleanup()


if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
    world_size = n_gpus
    run_demo(demo_serialization, world_size)
```
## Acknowledgements
Many thanks to PowerSGD paper author Thijs Vogels for the code review on the PowerSGD communication hook, as well as the comparison experiments, which show that the performance of the PowerSGD communication hook is on par with the implementation in the original paper.