Files
pytorch/torch/distributed/optim/post_localSGD_optimizer.py
joncrall 4618371da5 Integrate xdoctest - Rebased (#82797)
This is a new version of #15648 based on the latest master branch.

Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.

In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will disable those tests. (Unfortunately, I don't have a tool that will insert the `# xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)
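For illustration only, a doctest marked this way might look like the sketch below; the function and its example are hypothetical, but the `# xdoctest: +SKIP("reason")` directive is the one referenced above.

    def broadcast_message(tensor):
        """
        Hypothetical example of a doctest that only runs with an initialized
        process group, so it is skipped on the doc-test dashboards.

        Example::
            >>> # xdoctest: +SKIP("requires an initialized process group")
            >>> import torch.distributed as dist
            >>> dist.broadcast(tensor, src=0)
        """
        ...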

Fixes https://github.com/pytorch/pytorch/issues/71105

@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
2022-08-12 02:08:01 +00:00

111 lines
4.3 KiB
Python

import warnings

import torch
import torch.distributed.algorithms.model_averaging.averagers as averagers


class PostLocalSGDOptimizer(torch.optim.Optimizer):
r"""
Wraps an arbitrary :class:`torch.optim.Optimizer` and runs `post-local SGD <https://arxiv.org/abs/1808.07217>`_,
This optimizer runs local optimizer at every step.
After the warm-up stage, it averages parameters periodically afer the local optimizer is applied.
Args:
optim: The local optimizer.
averager: A model averager instance to run post-localSGD algorithm.
Example::
>>> # xdoctest: +SKIP("undefined variables")
>>> import torch
>>> import torch.distributed as dist
>>> import torch.distributed.algorithms.model_averaging.averagers as averagers
>>> import torch.nn as nn
>>> from torch.distributed.optim import PostLocalSGDOptimizer
>>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
>>> PostLocalSGDState,
>>> post_localSGD_hook,
>>> )
>>>
>>> model = nn.parallel.DistributedDataParallel(
>>> module, device_ids=[rank], output_device=rank
>>> )
>>>
>>> # Register a post-localSGD communication hook.
>>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
>>> model.register_comm_hook(state, post_localSGD_hook)
>>>
>>> # Create a post-localSGD optimizer that wraps a local optimizer.
>>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as
>>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``.
>>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
>>> opt = PostLocalSGDOptimizer(
>>> optim=local_optim,
>>> averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100)
>>> )
>>>
>>> # In the first 100 steps, DDP runs global gradient averaging at every step.
>>> # After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default),
>>> # and post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer.
>>> for step in range(0, 200):
>>> opt.zero_grad()
>>> loss = loss_fn(output, labels)
>>> loss.backward()
>>> opt.step()
"""
    def __init__(
        self,
        optim: torch.optim.Optimizer,
        averager: averagers.ModelAverager
    ):
        self.optim = optim
        self.param_groups = self.optim.param_groups
        self.averager = averager

    @property
    def state(self):
        return self.optim.state

    def __repr__(self):
        return self.optim.__repr__()

    def state_dict(self):
        r"""
        This is the same as :class:`torch.optim.Optimizer` :meth:`state_dict`,
        but adds an extra entry to record the model averager's step to the checkpoint
        to ensure reload does not cause unnecessary warm up again.
        """
        optim_state_dict = self.optim.state_dict()
        optim_state_dict['step'] = self.averager.step
        return optim_state_dict

    def load_state_dict(self, state_dict):
        r"""
        This is the same as :class:`torch.optim.Optimizer` :meth:`load_state_dict`,
        but also restores the model averager's step value to the one
        saved in the provided ``state_dict``.

        If there is no ``"step"`` entry in ``state_dict``,
        it will raise a warning and initialize the model averager's step to 0.
        """
        self.optim.load_state_dict(state_dict)
        if 'step' in state_dict:
            self.averager.step = state_dict['step']
        else:
            warnings.warn(
                "Loaded state dict does not contain a step counter for an averager. "
                "Setting step counter to 0."
            )
            self.averager.step = 0

    def step(self):
        r"""
        Performs a single optimization step (parameter update).
        """
        self.optim.step()
        self.averager.average_parameters(params=self.param_groups)

    def zero_grad(self, set_to_none: bool = False):  # type: ignore[override]
        self.optim.zero_grad(set_to_none=set_to_none)

    def add_param_group(self, param_group):
        self.optim.add_param_group(param_group)
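
A minimal checkpointing sketch, assuming `opt` is a `PostLocalSGDOptimizer` constructed as in the docstring example (the file name is arbitrary): the averager's step counter travels with the optimizer state, so resuming does not repeat the warm-up phase.

    # Save: state_dict() returns the wrapped optimizer's state plus a "step" entry
    # recording the averager's progress.
    torch.save(opt.state_dict(), "post_local_sgd_checkpoint.pt")

    # Load: load_state_dict() restores the inner optimizer and the averager's step;
    # if the "step" entry is missing, it warns and resets the step to 0.
    opt.load_state_dict(torch.load("post_local_sgd_checkpoint.pt"))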