[1/n] Consolidate replicate and DDP: setup ufmt for distributed.py (#96597)
As we already enabled ufmt for the composable APIs in https://github.com/pytorch/pytorch/pull/90873, it seems like a good idea to enable ufmt for the other distributed APIs as well. This change sets up ufmt for DDP.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96597
Approved by: https://github.com/rohan-varma
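As a quick illustration, here is a minimal before/after sketch of the kind of rewrite ufmt (usort for import sorting plus black for code formatting) applies to this file; the example lines are taken directly from the diff below:

# Before ufmt: import order and name order are ad hoc.
#   from enum import Enum, auto
#   from typing import Callable, Any, Type, Tuple, Optional, List
# After ufmt: imported names are sorted case-insensitively into one canonical order.
from enum import auto, Enum
from typing import Any, Callable, List, Optional, Tuple, Type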
Committed by: PyTorch MergeBot
Parent: 24ce3a7c34
Commit: 13538c88b3
@@ -862,6 +862,7 @@ include_patterns = [
     'test/test_value_ranges.py',
     'torch/utils/_sympy/interp.py',
     'torch/utils/_sympy/reference.py',
+    'torch/nn/parallel/distributed.py',
 ]
 command = [
     'python3',
@@ -1,6 +1,5 @@
 import copy
 import functools
-from collections import defaultdict, deque
 import inspect
 import itertools
 import logging
@@ -8,33 +7,30 @@ import os
 import sys
 import warnings
 import weakref
+from collections import defaultdict, deque
 from contextlib import contextmanager
 from dataclasses import dataclass, fields, is_dataclass
-from enum import Enum, auto
-from typing import Callable, Any, Type, Tuple, Optional, List
+from enum import auto, Enum
+from typing import Any, Callable, List, Optional, Tuple, Type

 import torch
 import torch.distributed as dist
 from torch.autograd import Function, Variable
-from torch.distributed.algorithms.join import (
-    Join,
-    Joinable,
-    JoinHook,
-)
+from torch.distributed.algorithms.join import Join, Joinable, JoinHook

 from torch.utils._pytree import tree_flatten, tree_unflatten

 RPC_AVAILABLE = False
 if dist.is_available():
+    from torch.distributed.distributed_c10d import _get_default_group, ReduceOp
     from torch.distributed.utils import (
-        _verify_param_shape_across_processes,
+        _alloc_storage,
+        _apply_to_tensors,
+        _free_storage,
         _sync_module_states,
         _to_kwargs,
-        _apply_to_tensors,
-        _alloc_storage,
-        _free_storage,
+        _verify_param_shape_across_processes,
     )
-    from torch.distributed.distributed_c10d import ReduceOp, _get_default_group
 if torch.distributed.rpc.is_available():
     RPC_AVAILABLE = True
     from torch.distributed.rpc import RRef
@@ -48,6 +44,7 @@ __all__ = ["DistributedDataParallel"]

 logger = logging.getLogger(__name__)

+
 @dataclass
 class _MixedPrecision:
     """
@@ -90,26 +87,28 @@ class _MixedPrecision:
     # in full precision. For DDP, this can be implemented by not performing the
     # parameter cast for BN and LN units.

+
 def _cast_buffers(mixed_precision_config, root_module):
     """
     Casts buffers to the given ``buffer_dtype``.
     """
     for buf in root_module.buffers():
-        if hasattr(buf, '_ddp_ignored') and buf._ddp_ignored:
+        if hasattr(buf, "_ddp_ignored") and buf._ddp_ignored:
             continue

         buf.data = buf.to(dtype=mixed_precision_config.buffer_dtype)

+
 def _setup_mixed_precision_params(mixed_precision_config, root_module):
     """
     Creates and frees storage for the mixed precision parameters.
     """
     for param in root_module.parameters():
         # Do not setup mixed precision for DDP ignored parameters.
-        if hasattr(param, '_ddp_ignored') and param._ddp_ignored:
+        if hasattr(param, "_ddp_ignored") and param._ddp_ignored:
             continue

-        if not hasattr(param, '_mp_param'):
+        if not hasattr(param, "_mp_param"):
             param._mp_param = torch.zeros_like(
                 param,
                 device=param.device,
@@ -121,6 +120,7 @@ def _setup_mixed_precision_params(mixed_precision_config, root_module):
         # back to at the end of forward / backward.
         param._fp_param = param.data

+
 def _cast_forward_inputs(
     input_dtype: Optional[torch.dtype],
     *args: Any,
@@ -130,15 +130,14 @@ def _cast_forward_inputs(
     Casts input args and kwargs to the given input_dtype. Note that only
     floating point tensors are cast.
     """
+
     def cast_fn(x: torch.Tensor) -> torch.Tensor:
         if not torch.is_floating_point(x) or x.dtype == input_dtype:
             return x
         return x.to(input_dtype)

-    return (
-        _apply_to_tensors(cast_fn, args),
-        _apply_to_tensors(cast_fn, kwargs)
-    )
+    return (_apply_to_tensors(cast_fn, args), _apply_to_tensors(cast_fn, kwargs))

+
 def _tree_flatten_with_rref(output):
     output_is_rref = RPC_AVAILABLE and isinstance(output, RRef)
@@ -265,8 +264,7 @@ class _DDPSink(Function):
         ctx.reducer = reducer
         ctx.state_dict = state_dict
         ret = tuple(
-            inp.clone() if isinstance(inp, torch.Tensor) else inp
-            for inp in inputs
+            inp.clone() if isinstance(inp, torch.Tensor) else inp for inp in inputs
         )
         return ret

@@ -274,10 +272,7 @@ class _DDPSink(Function):
     def backward(ctx, *grad_outputs):
         # Enqueue delay allreduce for static graph training on the first
         # iteration.
-        if (
-            ctx.state_dict["static_graph"]
-            and ctx.state_dict["num_iterations"] == 1
-        ):
+        if ctx.state_dict["static_graph"] and ctx.state_dict["num_iterations"] == 1:
             Variable._execution_engine.queue_callback(  # type: ignore[call-arg,misc]
                 ctx.reducer._delay_all_reduce
             )
@@ -316,9 +311,7 @@ class _DDPJoinHook(JoinHook):
         ddp._check_and_sync_module_buffers()

         # Check if need to sync in the backward pass
-        work = ddp._check_global_requires_backward_grad_sync(
-            is_joined_rank=True
-        )
+        work = ddp._check_global_requires_backward_grad_sync(is_joined_rank=True)
         work.wait()
         should_sync_backwards = work.result()[0].item() != 0
         # Forward parameter sync is disabled in the next iteration if we
@@ -666,11 +659,13 @@ class DistributedDataParallel(Module, Joinable):
         super().__init__()
         Joinable.__init__(self)
         self.logger = None
-        if bool(delay_all_reduce_named_params is not None) != bool(param_to_hook_all_reduce is not None):
+        if bool(delay_all_reduce_named_params is not None) != bool(
+            param_to_hook_all_reduce is not None
+        ):
             self._log_and_throw(
                 ValueError,
                 "delay_all_reduce_named_params and param_to_hook_all_reduce "
-                "need to be set at the same time."
+                "need to be set at the same time.",
             )

         self._delay_all_reduce_params = []
@@ -683,7 +678,11 @@ class DistributedDataParallel(Module, Joinable):
                 self.parameters_to_ignore.add(name)
                 self._delay_all_reduce_params.append(param)

-        self._module_parameters = [p for n, p in module.named_parameters() if n not in self.parameters_to_ignore]
+        self._module_parameters = [
+            p
+            for n, p in module.named_parameters()
+            if n not in self.parameters_to_ignore
+        ]
         if not any((p.requires_grad for p in self._module_parameters)):
             if len(self._delay_all_reduce_params):
                 logger.info("Delay the AllReduce of all parameters.")
@@ -700,8 +699,12 @@ class DistributedDataParallel(Module, Joinable):
                 "device_ids can only be None or contain a single element.",
             )

-        self.is_multi_device_module = len({p.device for p in self._module_parameters}) > 1
-        distinct_device_types = {p.device.type for p in self._module_parameters if p.device is not None}
+        self.is_multi_device_module = (
+            len({p.device for p in self._module_parameters}) > 1
+        )
+        distinct_device_types = {
+            p.device.type for p in self._module_parameters if p.device is not None
+        }
         if len(distinct_device_types) != 1:
             self._log_and_throw(
                 ValueError,
@@ -813,9 +816,7 @@ class DistributedDataParallel(Module, Joinable):
             params_and_buffers_to_ignore=self.parameters_to_ignore,
         )
         # In debug mode, build a mapping of parameter index -> parameter.
-        param_to_name_mapping = self._build_debug_param_to_name_mapping(
-            parameters
-        )
+        param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters)

         # Builds reducer.
         self._ddp_init_helper(
@@ -839,15 +840,19 @@ class DistributedDataParallel(Module, Joinable):
             # before running computation.
             for module in self.module.modules():
                 module.register_forward_pre_hook(
-                    self._module_wait_for_copy_hook, prepend=False, with_kwargs=True,
+                    self._module_wait_for_copy_hook,
+                    prepend=False,
+                    with_kwargs=True,
                 )
             # Set up callbacks in backward to upcast and use full precision
             # params. TODO (rohan-varma): Make this compose with general
             # comm hooks and apply_optimizer_in_backward. Importing inline to
             # avoid circular import issue.
             from torch.distributed.algorithms.ddp_comm_hooks.mixed_precision_hooks import (
-                _reducer_allreduce_and_upcast_hook, _AllreduceUpcastHookState
+                _AllreduceUpcastHookState,
+                _reducer_allreduce_and_upcast_hook,
             )
+
             upcast_hook_state = _AllreduceUpcastHookState(
                 ddp_weakref=weakref.ref(self),
                 upcast_stream=torch.cuda.Stream(),
@@ -892,7 +897,9 @@ class DistributedDataParallel(Module, Joinable):

         def _delayed_all_reduce(grad):
             self._delay_grad_buffer.div_(world_size)  # type: ignore[union-attr]
-            _ = dist.all_reduce(self._delay_grad_buffer, group=process_group, async_op=True)
+            _ = dist.all_reduce(
+                self._delay_grad_buffer, group=process_group, async_op=True
+            )
             return grad

         param_to_hook_all_reduce.register_hook(_delayed_all_reduce)
@@ -931,13 +938,13 @@ class DistributedDataParallel(Module, Joinable):
         # ping https://github.com/pytorch/pytorch/issues/90052.
         # NOTE: we use self._module_parameters instead of .parameters() since
         # the former excludes ignored (non-DDP managed) parameters.
-        if any(
-            hasattr(p, '_in_backward_optimizers') for p in self._module_parameters
-        ):
+        if any(hasattr(p, "_in_backward_optimizers") for p in self._module_parameters):
             # Remove hooks that apply_optim_in_backward had registered because
             # DDP customizes how optimizer is overlapped with backward due to
             # the allreduce.
-            param_to_handle_map = dist.optim.apply_optimizer_in_backward.param_to_optim_hook_handle_map
+            param_to_handle_map = (
+                dist.optim.apply_optimizer_in_backward.param_to_optim_hook_handle_map
+            )
             for p in self._module_parameters:
                 for handle in param_to_handle_map.get(p, []):
                     handle.remove()
@@ -947,8 +954,9 @@ class DistributedDataParallel(Module, Joinable):
             # Note: importing in function, otherwise this will cause a circular
             # import.
             from torch.distributed.algorithms.ddp_comm_hooks.optimizer_overlap_hooks import (
-                _apply_optim_in_backward_hook
+                _apply_optim_in_backward_hook,
             )
+
             self.register_comm_hook(
                 (reducer_weakref, self.process_group),
                 _apply_optim_in_backward_hook(
@@ -981,11 +989,7 @@ class DistributedDataParallel(Module, Joinable):
         """
         self.reducer._autograd_hook(idx)  # type: ignore[attr-defined]

-    def _root_copy_hook(
-        self,
-        *args: Any,
-        **kwargs: Any
-    ) -> None:
+    def _root_copy_hook(self, *args: Any, **kwargs: Any) -> None:
         """
         When training with DDP mixed precision, this root pre-forward hook kicks
         off low precision copies on a separate stream and creates respective
@@ -999,7 +1003,7 @@ class DistributedDataParallel(Module, Joinable):
             for submodule in self.module.modules():
                 for param in submodule.parameters(recurse=False):
                     # Do not cast DDP ignored parameters.
-                    if hasattr(param, '_ddp_ignored') and param._ddp_ignored:
+                    if hasattr(param, "_ddp_ignored") and param._ddp_ignored:
                         continue
                     _alloc_storage(param._mp_param, param.size())
                     # copy() implicitly casts to low precision
@@ -1042,10 +1046,7 @@ class DistributedDataParallel(Module, Joinable):
         event.wait(stream=torch.cuda.current_stream())
         for p in module.parameters(recurse=False):
             # Don't register hooks if param does not require grad
-            if (
-                not p.requires_grad
-                or (hasattr(p, '_ddp_ignored') and p._ddp_ignored)
-            ):
+            if not p.requires_grad or (hasattr(p, "_ddp_ignored") and p._ddp_ignored):
                 continue
             # We need to register autograd hook here instead of DDP's ctor
             # since we're working with the low precision param. Register them
@@ -1183,9 +1184,7 @@ class DistributedDataParallel(Module, Joinable):
         self.__dict__.setdefault("require_backward_grad_sync", True)
         parameters, expect_sparse_gradient = self._build_params_for_reducer()
         # In debug mode, build a mapping of parameter index -> parameter.
-        param_to_name_mapping = self._build_debug_param_to_name_mapping(
-            parameters
-        )
+        param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters)
         # Builds reducer.
         self._ddp_init_helper(
             parameters,
@@ -1211,8 +1210,7 @@ class DistributedDataParallel(Module, Joinable):
                 # parameters through _former_parameters.
                 for param_name, param in module.named_parameters(recurse=False)
                 if param.requires_grad
-                and f"{module_name}.{param_name}"
-                not in self.parameters_to_ignore
+                and f"{module_name}.{param_name}" not in self.parameters_to_ignore
             ]
         ]

@@ -1238,8 +1236,7 @@ class DistributedDataParallel(Module, Joinable):
         # Build list of booleans indicating whether or not to expect sparse
         # gradients for the corresponding parameters.
         expect_sparse_gradient = [
-            produces_sparse_gradient(module)
-            for module, _ in modules_and_parameters
+            produces_sparse_gradient(module) for module, _ in modules_and_parameters
         ]

         self._assign_modules_buffers()
@@ -1265,17 +1262,14 @@ class DistributedDataParallel(Module, Joinable):
         ]
         # Dict[str, tensor] representing module buffers not ignored by DDP.
         self.named_module_buffers = {
-            buffer_name: buffer
-            for (buffer, buffer_name) in named_module_buffers
+            buffer_name: buffer for (buffer, buffer_name) in named_module_buffers
         }

     def _build_debug_param_to_name_mapping(self, parameters):
         if dist.get_debug_level() == dist.DebugLevel.OFF:
             return {}

-        param_to_param_index = {
-            parameters[i]: i for i in range(len(parameters))
-        }
+        param_to_param_index = {parameters[i]: i for i in range(len(parameters))}
         param_set = set(parameters)
         param_index_to_param_fqn = {}
         for module_name, module in self.module.named_modules():
@@ -1400,9 +1394,7 @@ class DistributedDataParallel(Module, Joinable):
             )
             args, kwargs = inputs[0], kwargs[0]  # type: ignore[index]
             # Cast inputs to reduced precision if needed.
-            if (
-                self.mixed_precision is not None
-            ):
+            if self.mixed_precision is not None:
                 args, kwargs = _cast_forward_inputs(
                     self.mixed_precision.param_dtype,
                     *args,
@@ -1413,9 +1405,7 @@ class DistributedDataParallel(Module, Joinable):
         else:
             # Cast inputs to reduced precision if needed.
             # TODO (rohan-varma) test this codepath.
-            if (
-                self.mixed_precision is not None
-            ):
+            if self.mixed_precision is not None:
                 inputs, kwargs = _cast_forward_inputs(
                     self.mixed_precision.param_dtype,
                     *inputs,
@@ -1446,9 +1436,7 @@ class DistributedDataParallel(Module, Joinable):
                 self._delay_grad_buffer.zero_()

     def forward(self, *inputs, **kwargs):
-        with torch.autograd.profiler.record_function(
-            "DistributedDataParallel.forward"
-        ):
+        with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
             if self._delay_all_reduce_all_params:
                 output = self.module.forward(*inputs, **kwargs)
                 self._clear_grad_buffer()
@@ -1475,9 +1463,7 @@ class DistributedDataParallel(Module, Joinable):
             # during forward computation.
             # This should be called only once during whole training period.
             if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
-                logger.info(
-                    "Reducer buckets have been rebuilt in this iteration."
-                )
+                logger.info("Reducer buckets have been rebuilt in this iteration.")
                 self._has_rebuilt_buckets = True

             # sync params according to location (before/after forward) user
@@ -1487,9 +1473,7 @@ class DistributedDataParallel(Module, Joinable):

             if self._join_config.enable:
                 # Notify joined ranks whether they should sync in backwards pass or not.
-                self._check_global_requires_backward_grad_sync(
-                    is_joined_rank=False
-                )
+                self._check_global_requires_backward_grad_sync(is_joined_rank=False)

             output = self._run_ddp_forward(*inputs, **kwargs)

@@ -1507,9 +1491,7 @@ class DistributedDataParallel(Module, Joinable):
                 # unused parameters. Only if `find_unused_parameters` is set.
                 if self.find_unused_parameters and not self.static_graph:
                     # Do not need to populate this for static graph.
-                    self.reducer.prepare_for_backward(
-                        list(_find_tensors(output))
-                    )
+                    self.reducer.prepare_for_backward(list(_find_tensors(output)))
                 else:
                     self.reducer.prepare_for_backward([])
             else:
@@ -1593,9 +1575,7 @@ class DistributedDataParallel(Module, Joinable):
     # the models have buffers that should be synchronized in the forward pass.
     def _check_and_sync_module_buffers(self):
         if self._check_sync_bufs_pre_fwd():
-            authoritative_rank = self._find_common_rank(
-                self._distributed_rank, False
-            )
+            authoritative_rank = self._find_common_rank(self._distributed_rank, False)
             self._sync_module_buffers(authoritative_rank)

     # When running in join model, agrees upon a common rank and broadcast model
@@ -1777,9 +1757,7 @@ class DistributedDataParallel(Module, Joinable):
                 cases for possibly better results.
                 Default is ``True``.
         """
-        divide_by_initial_world_size = kwargs.get(
-            "divide_by_initial_world_size", True
-        )
+        divide_by_initial_world_size = kwargs.get("divide_by_initial_world_size", True)
         return _DDPJoinHook(
             self, divide_by_initial_world_size=divide_by_initial_world_size
         )
@@ -1946,9 +1924,7 @@ class DistributedDataParallel(Module, Joinable):
         self.logger._set_comm_hook_name(str(comm_hook_type))
         dist._register_builtin_comm_hook(self.reducer, comm_hook_type)

-    def _register_fused_optim(
-        self, optim: Type, *args, optim_params=None, **kwargs
-    ):
+    def _register_fused_optim(self, optim: Type, *args, optim_params=None, **kwargs):
         r"""
         Registers an optimizer with DDP such that the optimization for a
         parameter will run immediately when that parameter's gradient is
@@ -2008,13 +1984,9 @@ class DistributedDataParallel(Module, Joinable):
         """
         # Note: importing in function, otherwise this will cause a circular
         # import as optimizer_overlap module needs to import DistributedDataParallel.
-        from torch.distributed.algorithms._optimizer_overlap import (
-            _as_overlapped_optim,
-        )
+        from torch.distributed.algorithms._optimizer_overlap import _as_overlapped_optim

-        overlapped_optim = _as_overlapped_optim(
-            optim, optim_params, *args, **kwargs
-        )
+        overlapped_optim = _as_overlapped_optim(optim, optim_params, *args, **kwargs)
         try:
             overlapped_optim.register_ddp(self)
         except NotImplementedError as e:
@@ -2088,9 +2060,7 @@ class DistributedDataParallel(Module, Joinable):

     def _sync_module_buffers(self, authoritative_rank):
         if not hasattr(self, "buffer_hook"):
-            self._default_broadcast_coalesced(
-                authoritative_rank=authoritative_rank
-            )
+            self._default_broadcast_coalesced(authoritative_rank=authoritative_rank)
         else:
             hook = self.buffer_hook.buffer_comm_hook
             state = self.buffer_hook.buffer_comm_hook_state
@@ -2111,9 +2081,7 @@ class DistributedDataParallel(Module, Joinable):
         if bucket_size is None:
             bucket_size = self.broadcast_bucket_size

-        self._distributed_broadcast_coalesced(
-            bufs, bucket_size, authoritative_rank
-        )
+        self._distributed_broadcast_coalesced(bufs, bucket_size, authoritative_rank)

     def _passing_sync_batchnorm_handle(self, module):
         for layer in module.modules():
@@ -2126,9 +2094,7 @@ class DistributedDataParallel(Module, Joinable):

     def _check_comm_hook(self, hook):
         if not callable(hook):
-            self._log_and_throw(
-                TypeError, "Communication hook must be callable."
-            )
+            self._log_and_throw(TypeError, "Communication hook must be callable.")

         sig = inspect.signature(hook)
         if (
@@ -2173,8 +2139,10 @@ class DistributedDataParallel(Module, Joinable):

     @staticmethod
     def _get_data_parallel_params(module, named_params=False):
-        for param in (module.parameters() if not named_params else module.named_parameters()):
-            if not hasattr(param, '_ddp_ignored'):
+        for param in (
+            module.parameters() if not named_params else module.named_parameters()
+        ):
+            if not hasattr(param, "_ddp_ignored"):
                 yield param

     @staticmethod