From 13538c88b38b8fe038e344ab0a588bb2462193c9 Mon Sep 17 00:00:00 2001
From: Charlie Yan
Date: Fri, 17 Mar 2023 03:12:23 +0000
Subject: [PATCH] [1/n] Consolidate `replicate` and `DDP`: setup ufmt for
 `distributed.py` (#96597)

Since we already enabled ufmt for the composable APIs in
https://github.com/pytorch/pytorch/pull/90873, it makes sense to enable
ufmt for the other distributed APIs as well. This change sets up ufmt
for DDP.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96597
Approved by: https://github.com/rohan-varma
---
 .lintrunner.toml                 |   1 +
 torch/nn/parallel/distributed.py | 190 +++++++++++++------------------
 2 files changed, 80 insertions(+), 111 deletions(-)

diff --git a/.lintrunner.toml b/.lintrunner.toml
index 88ef032f0c60..2b2a25e02d13 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -862,6 +862,7 @@ include_patterns = [
    'test/test_value_ranges.py',
    'torch/utils/_sympy/interp.py',
    'torch/utils/_sympy/reference.py',
+    'torch/nn/parallel/distributed.py',
]
command = [
    'python3',
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index 8caf687ab18c..74683a962d11 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -1,6 +1,5 @@
import copy
import functools
-from collections import defaultdict, deque
import inspect
import itertools
import logging
@@ -8,33 +7,30 @@ import os
import sys
import warnings
import weakref
+from collections import defaultdict, deque
from contextlib import contextmanager
from dataclasses import dataclass, fields, is_dataclass
-from enum import Enum, auto
-from typing import Callable, Any, Type, Tuple, Optional, List
+from enum import auto, Enum
+from typing import Any, Callable, List, Optional, Tuple, Type

import torch
import torch.distributed as dist
from torch.autograd import Function, Variable
-from torch.distributed.algorithms.join import (
-    Join,
-    Joinable,
-    JoinHook,
-)
+from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from torch.utils._pytree import tree_flatten, tree_unflatten

RPC_AVAILABLE = False
if dist.is_available():
+    from torch.distributed.distributed_c10d import _get_default_group, ReduceOp
    from torch.distributed.utils import (
-        _verify_param_shape_across_processes,
+        _alloc_storage,
+        _apply_to_tensors,
+        _free_storage,
        _sync_module_states,
        _to_kwargs,
-        _apply_to_tensors,
-        _alloc_storage,
-        _free_storage,
+        _verify_param_shape_across_processes,
    )
-    from torch.distributed.distributed_c10d import ReduceOp, _get_default_group
if torch.distributed.rpc.is_available():
    RPC_AVAILABLE = True
    from torch.distributed.rpc import RRef
@@ -48,6 +44,7 @@ __all__ = ["DistributedDataParallel"]

logger = logging.getLogger(__name__)

+
@dataclass
class _MixedPrecision:
    """
@@ -90,26 +87,28 @@ class _MixedPrecision:
# in full precision. For DDP, this can be implemented by not performing the
# parameter cast for BN and LN units.

+
def _cast_buffers(mixed_precision_config, root_module):
    """
    Casts buffers to the given ``buffer_dtype``.
    """
    for buf in root_module.buffers():
-        if hasattr(buf, '_ddp_ignored') and buf._ddp_ignored:
+        if hasattr(buf, "_ddp_ignored") and buf._ddp_ignored:
            continue

        buf.data = buf.to(dtype=mixed_precision_config.buffer_dtype)

+
def _setup_mixed_precision_params(mixed_precision_config, root_module):
    """
    Creates and frees storage for the mixed precision parameters.
    """
    for param in root_module.parameters():
        # Do not setup mixed precision for DDP ignored parameters.
- if hasattr(param, '_ddp_ignored') and param._ddp_ignored: + if hasattr(param, "_ddp_ignored") and param._ddp_ignored: continue - if not hasattr(param, '_mp_param'): + if not hasattr(param, "_mp_param"): param._mp_param = torch.zeros_like( param, device=param.device, @@ -121,6 +120,7 @@ def _setup_mixed_precision_params(mixed_precision_config, root_module): # back to at the end of forward / backward. param._fp_param = param.data + def _cast_forward_inputs( input_dtype: Optional[torch.dtype], *args: Any, @@ -130,15 +130,14 @@ def _cast_forward_inputs( Casts input args and kwargs to the given input_dtype. Note that only floating point tensors are cast. """ + def cast_fn(x: torch.Tensor) -> torch.Tensor: if not torch.is_floating_point(x) or x.dtype == input_dtype: return x return x.to(input_dtype) - return ( - _apply_to_tensors(cast_fn, args), - _apply_to_tensors(cast_fn, kwargs) - ) + return (_apply_to_tensors(cast_fn, args), _apply_to_tensors(cast_fn, kwargs)) + def _tree_flatten_with_rref(output): output_is_rref = RPC_AVAILABLE and isinstance(output, RRef) @@ -265,8 +264,7 @@ class _DDPSink(Function): ctx.reducer = reducer ctx.state_dict = state_dict ret = tuple( - inp.clone() if isinstance(inp, torch.Tensor) else inp - for inp in inputs + inp.clone() if isinstance(inp, torch.Tensor) else inp for inp in inputs ) return ret @@ -274,10 +272,7 @@ class _DDPSink(Function): def backward(ctx, *grad_outputs): # Enqueue delay allreduce for static graph training on the first # iteration. - if ( - ctx.state_dict["static_graph"] - and ctx.state_dict["num_iterations"] == 1 - ): + if ctx.state_dict["static_graph"] and ctx.state_dict["num_iterations"] == 1: Variable._execution_engine.queue_callback( # type: ignore[call-arg,misc] ctx.reducer._delay_all_reduce ) @@ -316,9 +311,7 @@ class _DDPJoinHook(JoinHook): ddp._check_and_sync_module_buffers() # Check if need to sync in the backward pass - work = ddp._check_global_requires_backward_grad_sync( - is_joined_rank=True - ) + work = ddp._check_global_requires_backward_grad_sync(is_joined_rank=True) work.wait() should_sync_backwards = work.result()[0].item() != 0 # Forward parameter sync is disabled in the next iteration if we @@ -666,11 +659,13 @@ class DistributedDataParallel(Module, Joinable): super().__init__() Joinable.__init__(self) self.logger = None - if bool(delay_all_reduce_named_params is not None) != bool(param_to_hook_all_reduce is not None): + if bool(delay_all_reduce_named_params is not None) != bool( + param_to_hook_all_reduce is not None + ): self._log_and_throw( ValueError, "delay_all_reduce_named_params and param_to_hook_all_reduce " - "need to be set at the same time." 
+ "need to be set at the same time.", ) self._delay_all_reduce_params = [] @@ -683,7 +678,11 @@ class DistributedDataParallel(Module, Joinable): self.parameters_to_ignore.add(name) self._delay_all_reduce_params.append(param) - self._module_parameters = [p for n, p in module.named_parameters() if n not in self.parameters_to_ignore] + self._module_parameters = [ + p + for n, p in module.named_parameters() + if n not in self.parameters_to_ignore + ] if not any((p.requires_grad for p in self._module_parameters)): if len(self._delay_all_reduce_params): logger.info("Delay the AllReduce of all parameters.") @@ -700,8 +699,12 @@ class DistributedDataParallel(Module, Joinable): "device_ids can only be None or contain a single element.", ) - self.is_multi_device_module = len({p.device for p in self._module_parameters}) > 1 - distinct_device_types = {p.device.type for p in self._module_parameters if p.device is not None} + self.is_multi_device_module = ( + len({p.device for p in self._module_parameters}) > 1 + ) + distinct_device_types = { + p.device.type for p in self._module_parameters if p.device is not None + } if len(distinct_device_types) != 1: self._log_and_throw( ValueError, @@ -813,9 +816,7 @@ class DistributedDataParallel(Module, Joinable): params_and_buffers_to_ignore=self.parameters_to_ignore, ) # In debug mode, build a mapping of parameter index -> parameter. - param_to_name_mapping = self._build_debug_param_to_name_mapping( - parameters - ) + param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters) # Builds reducer. self._ddp_init_helper( @@ -839,15 +840,19 @@ class DistributedDataParallel(Module, Joinable): # before running computation. for module in self.module.modules(): module.register_forward_pre_hook( - self._module_wait_for_copy_hook, prepend=False, with_kwargs=True, + self._module_wait_for_copy_hook, + prepend=False, + with_kwargs=True, ) # Set up callbacks in backward to upcast and use full precision # params. TODO (rohan-varma): Make this compose with general # comm hooks and apply_optimizer_in_backward. Importing inline to # avoid circular import issue. from torch.distributed.algorithms.ddp_comm_hooks.mixed_precision_hooks import ( - _reducer_allreduce_and_upcast_hook, _AllreduceUpcastHookState + _AllreduceUpcastHookState, + _reducer_allreduce_and_upcast_hook, ) + upcast_hook_state = _AllreduceUpcastHookState( ddp_weakref=weakref.ref(self), upcast_stream=torch.cuda.Stream(), @@ -892,7 +897,9 @@ class DistributedDataParallel(Module, Joinable): def _delayed_all_reduce(grad): self._delay_grad_buffer.div_(world_size) # type: ignore[union-attr] - _ = dist.all_reduce(self._delay_grad_buffer, group=process_group, async_op=True) + _ = dist.all_reduce( + self._delay_grad_buffer, group=process_group, async_op=True + ) return grad param_to_hook_all_reduce.register_hook(_delayed_all_reduce) @@ -931,13 +938,13 @@ class DistributedDataParallel(Module, Joinable): # ping https://github.com/pytorch/pytorch/issues/90052. # NOTE: we use self._module_parameters instead of .parameters() since # the former excludes ignored (non-DDP managed) parameters. - if any( - hasattr(p, '_in_backward_optimizers') for p in self._module_parameters - ): + if any(hasattr(p, "_in_backward_optimizers") for p in self._module_parameters): # Remove hooks that apply_optim_in_backward had registered because # DDP customizes how optimizer is overlapped with backward due to # the allreduce. 
- param_to_handle_map = dist.optim.apply_optimizer_in_backward.param_to_optim_hook_handle_map + param_to_handle_map = ( + dist.optim.apply_optimizer_in_backward.param_to_optim_hook_handle_map + ) for p in self._module_parameters: for handle in param_to_handle_map.get(p, []): handle.remove() @@ -947,8 +954,9 @@ class DistributedDataParallel(Module, Joinable): # Note: importing in function, otherwise this will cause a circular # import. from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import ( - _apply_optim_in_backward_hook + _apply_optim_in_backward_hook, ) + self.register_comm_hook( (reducer_weakref, self.process_group), _apply_optim_in_backward_hook( @@ -981,11 +989,7 @@ class DistributedDataParallel(Module, Joinable): """ self.reducer._autograd_hook(idx) # type: ignore[attr-defined] - def _root_copy_hook( - self, - *args: Any, - **kwargs: Any - ) -> None: + def _root_copy_hook(self, *args: Any, **kwargs: Any) -> None: """ When training with DDP mixed precision, this root pre-forward hook kicks off low precision copies on a separate stream and creates respective @@ -999,7 +1003,7 @@ class DistributedDataParallel(Module, Joinable): for submodule in self.module.modules(): for param in submodule.parameters(recurse=False): # Do not cast DDP ignored parameters. - if hasattr(param, '_ddp_ignored') and param._ddp_ignored: + if hasattr(param, "_ddp_ignored") and param._ddp_ignored: continue _alloc_storage(param._mp_param, param.size()) # copy() implicitly casts to low precision @@ -1042,10 +1046,7 @@ class DistributedDataParallel(Module, Joinable): event.wait(stream=torch.cuda.current_stream()) for p in module.parameters(recurse=False): # Don't register hooks if param does not require grad - if ( - not p.requires_grad - or (hasattr(p, '_ddp_ignored') and p._ddp_ignored) - ): + if not p.requires_grad or (hasattr(p, "_ddp_ignored") and p._ddp_ignored): continue # We need to register autograd hook here instead of DDP's ctor # since we're working with the low precision param. Register them @@ -1183,9 +1184,7 @@ class DistributedDataParallel(Module, Joinable): self.__dict__.setdefault("require_backward_grad_sync", True) parameters, expect_sparse_gradient = self._build_params_for_reducer() # In debug mode, build a mapping of parameter index -> parameter. - param_to_name_mapping = self._build_debug_param_to_name_mapping( - parameters - ) + param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters) # Builds reducer. self._ddp_init_helper( parameters, @@ -1211,8 +1210,7 @@ class DistributedDataParallel(Module, Joinable): # parameters through _former_parameters. for param_name, param in module.named_parameters(recurse=False) if param.requires_grad - and f"{module_name}.{param_name}" - not in self.parameters_to_ignore + and f"{module_name}.{param_name}" not in self.parameters_to_ignore ] ] @@ -1238,8 +1236,7 @@ class DistributedDataParallel(Module, Joinable): # Build list of booleans indicating whether or not to expect sparse # gradients for the corresponding parameters. expect_sparse_gradient = [ - produces_sparse_gradient(module) - for module, _ in modules_and_parameters + produces_sparse_gradient(module) for module, _ in modules_and_parameters ] self._assign_modules_buffers() @@ -1265,17 +1262,14 @@ class DistributedDataParallel(Module, Joinable): ] # Dict[str, tensor] representing module buffers not ignored by DDP. 
self.named_module_buffers = { - buffer_name: buffer - for (buffer, buffer_name) in named_module_buffers + buffer_name: buffer for (buffer, buffer_name) in named_module_buffers } def _build_debug_param_to_name_mapping(self, parameters): if dist.get_debug_level() == dist.DebugLevel.OFF: return {} - param_to_param_index = { - parameters[i]: i for i in range(len(parameters)) - } + param_to_param_index = {parameters[i]: i for i in range(len(parameters))} param_set = set(parameters) param_index_to_param_fqn = {} for module_name, module in self.module.named_modules(): @@ -1400,9 +1394,7 @@ class DistributedDataParallel(Module, Joinable): ) args, kwargs = inputs[0], kwargs[0] # type: ignore[index] # Cast inputs to reduced precision if needed. - if ( - self.mixed_precision is not None - ): + if self.mixed_precision is not None: args, kwargs = _cast_forward_inputs( self.mixed_precision.param_dtype, *args, @@ -1413,9 +1405,7 @@ class DistributedDataParallel(Module, Joinable): else: # Cast inputs to reduced precision if needed. # TODO (rohan-varma) test this codepath. - if ( - self.mixed_precision is not None - ): + if self.mixed_precision is not None: inputs, kwargs = _cast_forward_inputs( self.mixed_precision.param_dtype, *inputs, @@ -1446,9 +1436,7 @@ class DistributedDataParallel(Module, Joinable): self._delay_grad_buffer.zero_() def forward(self, *inputs, **kwargs): - with torch.autograd.profiler.record_function( - "DistributedDataParallel.forward" - ): + with torch.autograd.profiler.record_function("DistributedDataParallel.forward"): if self._delay_all_reduce_all_params: output = self.module.forward(*inputs, **kwargs) self._clear_grad_buffer() @@ -1475,9 +1463,7 @@ class DistributedDataParallel(Module, Joinable): # during forward computation. # This should be called only once during whole training period. if torch.is_grad_enabled() and self.reducer._rebuild_buckets(): - logger.info( - "Reducer buckets have been rebuilt in this iteration." - ) + logger.info("Reducer buckets have been rebuilt in this iteration.") self._has_rebuilt_buckets = True # sync params according to location (before/after forward) user @@ -1487,9 +1473,7 @@ class DistributedDataParallel(Module, Joinable): if self._join_config.enable: # Notify joined ranks whether they should sync in backwards pass or not. - self._check_global_requires_backward_grad_sync( - is_joined_rank=False - ) + self._check_global_requires_backward_grad_sync(is_joined_rank=False) output = self._run_ddp_forward(*inputs, **kwargs) @@ -1507,9 +1491,7 @@ class DistributedDataParallel(Module, Joinable): # unused parameters. Only if `find_unused_parameters` is set. if self.find_unused_parameters and not self.static_graph: # Do not need to populate this for static graph. - self.reducer.prepare_for_backward( - list(_find_tensors(output)) - ) + self.reducer.prepare_for_backward(list(_find_tensors(output))) else: self.reducer.prepare_for_backward([]) else: @@ -1593,9 +1575,7 @@ class DistributedDataParallel(Module, Joinable): # the models have buffers that should be synchronized in the forward pass. def _check_and_sync_module_buffers(self): if self._check_sync_bufs_pre_fwd(): - authoritative_rank = self._find_common_rank( - self._distributed_rank, False - ) + authoritative_rank = self._find_common_rank(self._distributed_rank, False) self._sync_module_buffers(authoritative_rank) # When running in join model, agrees upon a common rank and broadcast model @@ -1777,9 +1757,7 @@ class DistributedDataParallel(Module, Joinable): cases for possibly better results. 
Default is ``True``. """ - divide_by_initial_world_size = kwargs.get( - "divide_by_initial_world_size", True - ) + divide_by_initial_world_size = kwargs.get("divide_by_initial_world_size", True) return _DDPJoinHook( self, divide_by_initial_world_size=divide_by_initial_world_size ) @@ -1946,9 +1924,7 @@ class DistributedDataParallel(Module, Joinable): self.logger._set_comm_hook_name(str(comm_hook_type)) dist._register_builtin_comm_hook(self.reducer, comm_hook_type) - def _register_fused_optim( - self, optim: Type, *args, optim_params=None, **kwargs - ): + def _register_fused_optim(self, optim: Type, *args, optim_params=None, **kwargs): r""" Registers an optimizer with DDP such that the optimization for a parameter will run immediately when that parameter's gradient is @@ -2008,13 +1984,9 @@ class DistributedDataParallel(Module, Joinable): """ # Note: importing in function, otherwise this will cause a circular # import as optimizer_overlap module needs to import DistributedDataParallel. - from torch.distributed.algorithms._optimizer_overlap import ( - _as_overlapped_optim, - ) + from torch.distributed.algorithms._optimizer_overlap import _as_overlapped_optim - overlapped_optim = _as_overlapped_optim( - optim, optim_params, *args, **kwargs - ) + overlapped_optim = _as_overlapped_optim(optim, optim_params, *args, **kwargs) try: overlapped_optim.register_ddp(self) except NotImplementedError as e: @@ -2088,9 +2060,7 @@ class DistributedDataParallel(Module, Joinable): def _sync_module_buffers(self, authoritative_rank): if not hasattr(self, "buffer_hook"): - self._default_broadcast_coalesced( - authoritative_rank=authoritative_rank - ) + self._default_broadcast_coalesced(authoritative_rank=authoritative_rank) else: hook = self.buffer_hook.buffer_comm_hook state = self.buffer_hook.buffer_comm_hook_state @@ -2111,9 +2081,7 @@ class DistributedDataParallel(Module, Joinable): if bucket_size is None: bucket_size = self.broadcast_bucket_size - self._distributed_broadcast_coalesced( - bufs, bucket_size, authoritative_rank - ) + self._distributed_broadcast_coalesced(bufs, bucket_size, authoritative_rank) def _passing_sync_batchnorm_handle(self, module): for layer in module.modules(): @@ -2126,9 +2094,7 @@ class DistributedDataParallel(Module, Joinable): def _check_comm_hook(self, hook): if not callable(hook): - self._log_and_throw( - TypeError, "Communication hook must be callable." - ) + self._log_and_throw(TypeError, "Communication hook must be callable.") sig = inspect.signature(hook) if ( @@ -2173,8 +2139,10 @@ class DistributedDataParallel(Module, Joinable): @staticmethod def _get_data_parallel_params(module, named_params=False): - for param in (module.parameters() if not named_params else module.named_parameters()): - if not hasattr(param, '_ddp_ignored'): + for param in ( + module.parameters() if not named_params else module.named_parameters() + ): + if not hasattr(param, "_ddp_ignored"): yield param @staticmethod