[1/n] Consolidate replicate and DDP: set up ufmt for distributed.py (#96597)

As we already enabled ufmt for the composable APIs in https://github.com/pytorch/pytorch/pull/90873, it seems like a good idea to enable ufmt for the other distributed APIs as well. This change sets up ufmt for DDP (torch/nn/parallel/distributed.py).
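For reference, once a file is listed under the UFMT linter's include_patterns, it can be re-formatted locally before sending changes. A minimal sketch of that workflow, assuming the standard lintrunner/ufmt tooling used in the PyTorch repo (these commands are illustrative and not part of this change; ufmt runs usort for import sorting plus black for formatting, which is what produces the import reordering and quote changes in the diff below):

    # apply lint fixes (including UFMT) to the newly covered file
    lintrunner -a torch/nn/parallel/distributed.py
    # or run ufmt (usort + black) on the file directly
    ufmt format torch/nn/parallel/distributed.py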

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96597
Approved by: https://github.com/rohan-varma
Charlie Yan
2023-03-17 03:12:23 +00:00
committed by PyTorch MergeBot
parent 24ce3a7c34
commit 13538c88b3
2 changed files with 80 additions and 111 deletions


@ -862,6 +862,7 @@ include_patterns = [
'test/test_value_ranges.py',
'torch/utils/_sympy/interp.py',
'torch/utils/_sympy/reference.py',
'torch/nn/parallel/distributed.py',
]
command = [
'python3',


@ -1,6 +1,5 @@
import copy
import functools
from collections import defaultdict, deque
import inspect
import itertools
import logging
@ -8,33 +7,30 @@ import os
import sys
import warnings
import weakref
from collections import defaultdict, deque
from contextlib import contextmanager
from dataclasses import dataclass, fields, is_dataclass
from enum import Enum, auto
from typing import Callable, Any, Type, Tuple, Optional, List
from enum import auto, Enum
from typing import Any, Callable, List, Optional, Tuple, Type
import torch
import torch.distributed as dist
from torch.autograd import Function, Variable
from torch.distributed.algorithms.join import (
Join,
Joinable,
JoinHook,
)
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from torch.utils._pytree import tree_flatten, tree_unflatten
RPC_AVAILABLE = False
if dist.is_available():
from torch.distributed.distributed_c10d import _get_default_group, ReduceOp
from torch.distributed.utils import (
_verify_param_shape_across_processes,
_alloc_storage,
_apply_to_tensors,
_free_storage,
_sync_module_states,
_to_kwargs,
_apply_to_tensors,
_alloc_storage,
_free_storage,
_verify_param_shape_across_processes,
)
from torch.distributed.distributed_c10d import ReduceOp, _get_default_group
if torch.distributed.rpc.is_available():
RPC_AVAILABLE = True
from torch.distributed.rpc import RRef
@ -48,6 +44,7 @@ __all__ = ["DistributedDataParallel"]
logger = logging.getLogger(__name__)
@dataclass
class _MixedPrecision:
"""
@ -90,26 +87,28 @@ class _MixedPrecision:
# in full precision. For DDP, this can be implemented by not performing the
# parameter cast for BN and LN units.
def _cast_buffers(mixed_precision_config, root_module):
"""
Casts buffers to the given ``buffer_dtype``.
"""
for buf in root_module.buffers():
if hasattr(buf, '_ddp_ignored') and buf._ddp_ignored:
if hasattr(buf, "_ddp_ignored") and buf._ddp_ignored:
continue
buf.data = buf.to(dtype=mixed_precision_config.buffer_dtype)
def _setup_mixed_precision_params(mixed_precision_config, root_module):
"""
Creates and frees storage for the mixed precision parameters.
"""
for param in root_module.parameters():
# Do not setup mixed precision for DDP ignored parameters.
if hasattr(param, '_ddp_ignored') and param._ddp_ignored:
if hasattr(param, "_ddp_ignored") and param._ddp_ignored:
continue
if not hasattr(param, '_mp_param'):
if not hasattr(param, "_mp_param"):
param._mp_param = torch.zeros_like(
param,
device=param.device,
@ -121,6 +120,7 @@ def _setup_mixed_precision_params(mixed_precision_config, root_module):
# back to at the end of forward / backward.
param._fp_param = param.data
def _cast_forward_inputs(
input_dtype: Optional[torch.dtype],
*args: Any,
@ -130,15 +130,14 @@ def _cast_forward_inputs(
Casts input args and kwargs to the given input_dtype. Note that only
floating point tensors are cast.
"""
def cast_fn(x: torch.Tensor) -> torch.Tensor:
if not torch.is_floating_point(x) or x.dtype == input_dtype:
return x
return x.to(input_dtype)
return (
_apply_to_tensors(cast_fn, args),
_apply_to_tensors(cast_fn, kwargs)
)
return (_apply_to_tensors(cast_fn, args), _apply_to_tensors(cast_fn, kwargs))
def _tree_flatten_with_rref(output):
output_is_rref = RPC_AVAILABLE and isinstance(output, RRef)
@ -265,8 +264,7 @@ class _DDPSink(Function):
ctx.reducer = reducer
ctx.state_dict = state_dict
ret = tuple(
inp.clone() if isinstance(inp, torch.Tensor) else inp
for inp in inputs
inp.clone() if isinstance(inp, torch.Tensor) else inp for inp in inputs
)
return ret
@ -274,10 +272,7 @@ class _DDPSink(Function):
def backward(ctx, *grad_outputs):
# Enqueue delay allreduce for static graph training on the first
# iteration.
if (
ctx.state_dict["static_graph"]
and ctx.state_dict["num_iterations"] == 1
):
if ctx.state_dict["static_graph"] and ctx.state_dict["num_iterations"] == 1:
Variable._execution_engine.queue_callback( # type: ignore[call-arg,misc]
ctx.reducer._delay_all_reduce
)
@ -316,9 +311,7 @@ class _DDPJoinHook(JoinHook):
ddp._check_and_sync_module_buffers()
# Check if need to sync in the backward pass
work = ddp._check_global_requires_backward_grad_sync(
is_joined_rank=True
)
work = ddp._check_global_requires_backward_grad_sync(is_joined_rank=True)
work.wait()
should_sync_backwards = work.result()[0].item() != 0
# Forward parameter sync is disabled in the next iteration if we
@ -666,11 +659,13 @@ class DistributedDataParallel(Module, Joinable):
super().__init__()
Joinable.__init__(self)
self.logger = None
if bool(delay_all_reduce_named_params is not None) != bool(param_to_hook_all_reduce is not None):
if bool(delay_all_reduce_named_params is not None) != bool(
param_to_hook_all_reduce is not None
):
self._log_and_throw(
ValueError,
"delay_all_reduce_named_params and param_to_hook_all_reduce "
"need to be set at the same time."
"need to be set at the same time.",
)
self._delay_all_reduce_params = []
@ -683,7 +678,11 @@ class DistributedDataParallel(Module, Joinable):
self.parameters_to_ignore.add(name)
self._delay_all_reduce_params.append(param)
self._module_parameters = [p for n, p in module.named_parameters() if n not in self.parameters_to_ignore]
self._module_parameters = [
p
for n, p in module.named_parameters()
if n not in self.parameters_to_ignore
]
if not any((p.requires_grad for p in self._module_parameters)):
if len(self._delay_all_reduce_params):
logger.info("Delay the AllReduce of all parameters.")
@ -700,8 +699,12 @@ class DistributedDataParallel(Module, Joinable):
"device_ids can only be None or contain a single element.",
)
self.is_multi_device_module = len({p.device for p in self._module_parameters}) > 1
distinct_device_types = {p.device.type for p in self._module_parameters if p.device is not None}
self.is_multi_device_module = (
len({p.device for p in self._module_parameters}) > 1
)
distinct_device_types = {
p.device.type for p in self._module_parameters if p.device is not None
}
if len(distinct_device_types) != 1:
self._log_and_throw(
ValueError,
@ -813,9 +816,7 @@ class DistributedDataParallel(Module, Joinable):
params_and_buffers_to_ignore=self.parameters_to_ignore,
)
# In debug mode, build a mapping of parameter index -> parameter.
param_to_name_mapping = self._build_debug_param_to_name_mapping(
parameters
)
param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters)
# Builds reducer.
self._ddp_init_helper(
@ -839,15 +840,19 @@ class DistributedDataParallel(Module, Joinable):
# before running computation.
for module in self.module.modules():
module.register_forward_pre_hook(
self._module_wait_for_copy_hook, prepend=False, with_kwargs=True,
self._module_wait_for_copy_hook,
prepend=False,
with_kwargs=True,
)
# Set up callbacks in backward to upcast and use full precision
# params. TODO (rohan-varma): Make this compose with general
# comm hooks and apply_optimizer_in_backward. Importing inline to
# avoid circular import issue.
from torch.distributed.algorithms.ddp_comm_hooks.mixed_precision_hooks import (
_reducer_allreduce_and_upcast_hook, _AllreduceUpcastHookState
_AllreduceUpcastHookState,
_reducer_allreduce_and_upcast_hook,
)
upcast_hook_state = _AllreduceUpcastHookState(
ddp_weakref=weakref.ref(self),
upcast_stream=torch.cuda.Stream(),
@ -892,7 +897,9 @@ class DistributedDataParallel(Module, Joinable):
def _delayed_all_reduce(grad):
self._delay_grad_buffer.div_(world_size) # type: ignore[union-attr]
_ = dist.all_reduce(self._delay_grad_buffer, group=process_group, async_op=True)
_ = dist.all_reduce(
self._delay_grad_buffer, group=process_group, async_op=True
)
return grad
param_to_hook_all_reduce.register_hook(_delayed_all_reduce)
@ -931,13 +938,13 @@ class DistributedDataParallel(Module, Joinable):
# ping https://github.com/pytorch/pytorch/issues/90052.
# NOTE: we use self._module_parameters instead of .parameters() since
# the former excludes ignored (non-DDP managed) parameters.
if any(
hasattr(p, '_in_backward_optimizers') for p in self._module_parameters
):
if any(hasattr(p, "_in_backward_optimizers") for p in self._module_parameters):
# Remove hooks that apply_optim_in_backward had registered because
# DDP customizes how optimizer is overlapped with backward due to
# the allreduce.
param_to_handle_map = dist.optim.apply_optimizer_in_backward.param_to_optim_hook_handle_map
param_to_handle_map = (
dist.optim.apply_optimizer_in_backward.param_to_optim_hook_handle_map
)
for p in self._module_parameters:
for handle in param_to_handle_map.get(p, []):
handle.remove()
@ -947,8 +954,9 @@ class DistributedDataParallel(Module, Joinable):
# Note: importing in function, otherwise this will cause a circular
# import.
from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import (
_apply_optim_in_backward_hook
_apply_optim_in_backward_hook,
)
self.register_comm_hook(
(reducer_weakref, self.process_group),
_apply_optim_in_backward_hook(
@ -981,11 +989,7 @@ class DistributedDataParallel(Module, Joinable):
"""
self.reducer._autograd_hook(idx) # type: ignore[attr-defined]
def _root_copy_hook(
self,
*args: Any,
**kwargs: Any
) -> None:
def _root_copy_hook(self, *args: Any, **kwargs: Any) -> None:
"""
When training with DDP mixed precision, this root pre-forward hook kicks
off low precision copies on a separate stream and creates respective
@ -999,7 +1003,7 @@ class DistributedDataParallel(Module, Joinable):
for submodule in self.module.modules():
for param in submodule.parameters(recurse=False):
# Do not cast DDP ignored parameters.
if hasattr(param, '_ddp_ignored') and param._ddp_ignored:
if hasattr(param, "_ddp_ignored") and param._ddp_ignored:
continue
_alloc_storage(param._mp_param, param.size())
# copy() implicitly casts to low precision
@ -1042,10 +1046,7 @@ class DistributedDataParallel(Module, Joinable):
event.wait(stream=torch.cuda.current_stream())
for p in module.parameters(recurse=False):
# Don't register hooks if param does not require grad
if (
not p.requires_grad
or (hasattr(p, '_ddp_ignored') and p._ddp_ignored)
):
if not p.requires_grad or (hasattr(p, "_ddp_ignored") and p._ddp_ignored):
continue
# We need to register autograd hook here instead of DDP's ctor
# since we're working with the low precision param. Register them
@ -1183,9 +1184,7 @@ class DistributedDataParallel(Module, Joinable):
self.__dict__.setdefault("require_backward_grad_sync", True)
parameters, expect_sparse_gradient = self._build_params_for_reducer()
# In debug mode, build a mapping of parameter index -> parameter.
param_to_name_mapping = self._build_debug_param_to_name_mapping(
parameters
)
param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters)
# Builds reducer.
self._ddp_init_helper(
parameters,
@ -1211,8 +1210,7 @@ class DistributedDataParallel(Module, Joinable):
# parameters through _former_parameters.
for param_name, param in module.named_parameters(recurse=False)
if param.requires_grad
and f"{module_name}.{param_name}"
not in self.parameters_to_ignore
and f"{module_name}.{param_name}" not in self.parameters_to_ignore
]
]
@ -1238,8 +1236,7 @@ class DistributedDataParallel(Module, Joinable):
# Build list of booleans indicating whether or not to expect sparse
# gradients for the corresponding parameters.
expect_sparse_gradient = [
produces_sparse_gradient(module)
for module, _ in modules_and_parameters
produces_sparse_gradient(module) for module, _ in modules_and_parameters
]
self._assign_modules_buffers()
@ -1265,17 +1262,14 @@ class DistributedDataParallel(Module, Joinable):
]
# Dict[str, tensor] representing module buffers not ignored by DDP.
self.named_module_buffers = {
buffer_name: buffer
for (buffer, buffer_name) in named_module_buffers
buffer_name: buffer for (buffer, buffer_name) in named_module_buffers
}
def _build_debug_param_to_name_mapping(self, parameters):
if dist.get_debug_level() == dist.DebugLevel.OFF:
return {}
param_to_param_index = {
parameters[i]: i for i in range(len(parameters))
}
param_to_param_index = {parameters[i]: i for i in range(len(parameters))}
param_set = set(parameters)
param_index_to_param_fqn = {}
for module_name, module in self.module.named_modules():
@ -1400,9 +1394,7 @@ class DistributedDataParallel(Module, Joinable):
)
args, kwargs = inputs[0], kwargs[0] # type: ignore[index]
# Cast inputs to reduced precision if needed.
if (
self.mixed_precision is not None
):
if self.mixed_precision is not None:
args, kwargs = _cast_forward_inputs(
self.mixed_precision.param_dtype,
*args,
@ -1413,9 +1405,7 @@ class DistributedDataParallel(Module, Joinable):
else:
# Cast inputs to reduced precision if needed.
# TODO (rohan-varma) test this codepath.
if (
self.mixed_precision is not None
):
if self.mixed_precision is not None:
inputs, kwargs = _cast_forward_inputs(
self.mixed_precision.param_dtype,
*inputs,
@ -1446,9 +1436,7 @@ class DistributedDataParallel(Module, Joinable):
self._delay_grad_buffer.zero_()
def forward(self, *inputs, **kwargs):
with torch.autograd.profiler.record_function(
"DistributedDataParallel.forward"
):
with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
if self._delay_all_reduce_all_params:
output = self.module.forward(*inputs, **kwargs)
self._clear_grad_buffer()
@ -1475,9 +1463,7 @@ class DistributedDataParallel(Module, Joinable):
# during forward computation.
# This should be called only once during whole training period.
if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
logger.info(
"Reducer buckets have been rebuilt in this iteration."
)
logger.info("Reducer buckets have been rebuilt in this iteration.")
self._has_rebuilt_buckets = True
# sync params according to location (before/after forward) user
@ -1487,9 +1473,7 @@ class DistributedDataParallel(Module, Joinable):
if self._join_config.enable:
# Notify joined ranks whether they should sync in backwards pass or not.
self._check_global_requires_backward_grad_sync(
is_joined_rank=False
)
self._check_global_requires_backward_grad_sync(is_joined_rank=False)
output = self._run_ddp_forward(*inputs, **kwargs)
@ -1507,9 +1491,7 @@ class DistributedDataParallel(Module, Joinable):
# unused parameters. Only if `find_unused_parameters` is set.
if self.find_unused_parameters and not self.static_graph:
# Do not need to populate this for static graph.
self.reducer.prepare_for_backward(
list(_find_tensors(output))
)
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
self.reducer.prepare_for_backward([])
else:
@ -1593,9 +1575,7 @@ class DistributedDataParallel(Module, Joinable):
# the models have buffers that should be synchronized in the forward pass.
def _check_and_sync_module_buffers(self):
if self._check_sync_bufs_pre_fwd():
authoritative_rank = self._find_common_rank(
self._distributed_rank, False
)
authoritative_rank = self._find_common_rank(self._distributed_rank, False)
self._sync_module_buffers(authoritative_rank)
# When running in join model, agrees upon a common rank and broadcast model
@ -1777,9 +1757,7 @@ class DistributedDataParallel(Module, Joinable):
cases for possibly better results.
Default is ``True``.
"""
divide_by_initial_world_size = kwargs.get(
"divide_by_initial_world_size", True
)
divide_by_initial_world_size = kwargs.get("divide_by_initial_world_size", True)
return _DDPJoinHook(
self, divide_by_initial_world_size=divide_by_initial_world_size
)
@ -1946,9 +1924,7 @@ class DistributedDataParallel(Module, Joinable):
self.logger._set_comm_hook_name(str(comm_hook_type))
dist._register_builtin_comm_hook(self.reducer, comm_hook_type)
def _register_fused_optim(
self, optim: Type, *args, optim_params=None, **kwargs
):
def _register_fused_optim(self, optim: Type, *args, optim_params=None, **kwargs):
r"""
Registers an optimizer with DDP such that the optimization for a
parameter will run immediately when that parameter's gradient is
@ -2008,13 +1984,9 @@ class DistributedDataParallel(Module, Joinable):
"""
# Note: importing in function, otherwise this will cause a circular
# import as optimizer_overlap module needs to import DistributedDataParallel.
from torch.distributed.algorithms._optimizer_overlap import (
_as_overlapped_optim,
)
from torch.distributed.algorithms._optimizer_overlap import _as_overlapped_optim
overlapped_optim = _as_overlapped_optim(
optim, optim_params, *args, **kwargs
)
overlapped_optim = _as_overlapped_optim(optim, optim_params, *args, **kwargs)
try:
overlapped_optim.register_ddp(self)
except NotImplementedError as e:
@ -2088,9 +2060,7 @@ class DistributedDataParallel(Module, Joinable):
def _sync_module_buffers(self, authoritative_rank):
if not hasattr(self, "buffer_hook"):
self._default_broadcast_coalesced(
authoritative_rank=authoritative_rank
)
self._default_broadcast_coalesced(authoritative_rank=authoritative_rank)
else:
hook = self.buffer_hook.buffer_comm_hook
state = self.buffer_hook.buffer_comm_hook_state
@ -2111,9 +2081,7 @@ class DistributedDataParallel(Module, Joinable):
if bucket_size is None:
bucket_size = self.broadcast_bucket_size
self._distributed_broadcast_coalesced(
bufs, bucket_size, authoritative_rank
)
self._distributed_broadcast_coalesced(bufs, bucket_size, authoritative_rank)
def _passing_sync_batchnorm_handle(self, module):
for layer in module.modules():
@ -2126,9 +2094,7 @@ class DistributedDataParallel(Module, Joinable):
def _check_comm_hook(self, hook):
if not callable(hook):
self._log_and_throw(
TypeError, "Communication hook must be callable."
)
self._log_and_throw(TypeError, "Communication hook must be callable.")
sig = inspect.signature(hook)
if (
@ -2173,8 +2139,10 @@ class DistributedDataParallel(Module, Joinable):
@staticmethod
def _get_data_parallel_params(module, named_params=False):
for param in (module.parameters() if not named_params else module.named_parameters()):
if not hasattr(param, '_ddp_ignored'):
for param in (
module.parameters() if not named_params else module.named_parameters()
):
if not hasattr(param, "_ddp_ignored"):
yield param
@staticmethod