[DDP] Use compiled_autograd to trace DDP backward allreduce (#110662)
**Summary**

The reducer of `DistributedDataParallel` is implemented in C++, so the allreduces it launches are not easy to trace. This PR modifies `DistributedDataParallel` to launch one allreduce per gradient when `compiled_autograd` is enabled. That allows `compiled_autograd` to trace the allreduces so that Inductor can later optimize (fuse) them.

**Key Logic**

1. If `ddp_python_hook` is True, we assume `compiled_autograd` is used. `DistributedDataParallel` registers `compiled_accum_grad_hook` for all parameters.
2. On the first forward() call, if `DistributedDataParallel` is not compiled, all `compiled_accum_grad_hook`s are deregistered. If `DistributedDataParallel` is compiled, all `compiled_accum_grad_hook`s will be compiled by `compiled_autograd`.
3. `compiled_accum_grad_hook` launches an allreduce to reduce the gradient of the parameter.

**Bucketing**

The compiled backward is slow because the allreduces are not bucketed; we rely on Inductor to bucket them. The bucketing is done in a separate PR.

Differential Revision: [D49428482](https://our.internmc.facebook.com/intern/diff/D49428482/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110662
Approved by: https://github.com/wconstab
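For orientation, the snippet below is a minimal usage sketch distilled from the test added in this PR (`test_replicate_with_compiler.py`); `MyModel` and `inputs` are placeholders and the process-group environment setup is elided:

```python
import torch
import torch.distributed as dist
from torch._dynamo import compiled_autograd
from torch.distributed._composable.replicate import replicate

# Switch DDP to the experimental Python reducer so that compiled_autograd can
# trace one allreduce per gradient instead of the opaque C++ reducer.
torch._dynamo.config.optimize_ddp = "python_reducer"

dist.init_process_group("nccl")
# MyModel and inputs are placeholders for a user model and input batch.
model = torch.compile(replicate(MyModel().cuda()), fullgraph=True)

def compiler_fn(gm):
    # The traced backward graph, including the per-gradient allreduces, is
    # handed to Inductor, which can later fuse (bucket) the collectives.
    return torch.compile(gm, fullgraph=True, backend="inductor")

with compiled_autograd.enable(compiler_fn):
    loss = model(inputs).sum()
    loss.backward()
```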
committed by PyTorch MergeBot
parent 113506d2d4
commit 1d2382f141

test/distributed/_composable/test_replicate_with_compiler.py (new file, 211 lines)
@@ -0,0 +1,211 @@
# Owner(s): ["oncall: distributed"]

import contextlib
import os
import unittest
from copy import deepcopy
from typing import Callable, Optional

import torch
import torch.distributed as dist
from torch import _inductor as inductor, nn
from torch._dynamo import compiled_autograd
from torch.distributed._composable.replicate import replicate
from torch.distributed.algorithms.ddp_comm_hooks import (
    default_hooks as ddp_default_hooks,
)
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    skip_if_lt_x_gpu,
    skip_if_rocm,
)
from torch.testing._internal.common_utils import run_tests
from torch.utils._triton import has_triton
from torch.utils.checkpoint import checkpoint


DIM = 2000


class Net(nn.Module):
    def __init__(self, checkpoint=False):
        super().__init__()
        self.fc1 = nn.Linear(DIM, DIM)
        self.fc2 = nn.Linear(DIM, DIM)
        self.fc3 = nn.Linear(DIM, DIM)
        self.fc4 = nn.Linear(DIM, DIM)
        self.use_checkpoint = checkpoint

    def forward(self, x):
        if self.use_checkpoint:
            _fc1 = checkpoint(self.fc1, x, use_reentrant=False)
        else:
            _fc1 = self.fc1(x)
        return self.fc4(self.fc3(self.fc2(_fc1)))


def compiler_fn(no_inductor):
    def _compiler_fn(gm):
        def inner_compiler(gm_, example_inputs_):
            if no_inductor:
                return gm_
            else:
                return inductor.compile(gm_, example_inputs_)

        gm = torch.compile(gm, fullgraph=True, backend=inner_compiler)
        return gm

    return _compiler_fn


class ReplicateTest(MultiProcessTestCase):
    @property
    def world_size(self) -> int:
        return 2

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    def _test_compile(
        self,
        *,
        use_gpu: bool,
        no_sync: bool,
        setup_func: Optional[Callable] = None,
        no_inductor: bool = False,
        no_compile_forward: bool = False,
    ):
        backend = "nccl" if use_gpu else "gloo"
        dist.init_process_group(
            backend=backend,
            rank=self.rank,
            world_size=self.world_size,
            store=dist.FileStore(self.file_name, self.world_size),
        )
        if use_gpu:
            torch.cuda.set_device(f"cuda:{self.rank}")
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")

        torch._dynamo.config.optimize_ddp = (
            "python_reducer_without_compiled_forward"
            if no_compile_forward
            else "python_reducer"
        )
        torch.manual_seed(123)
        model = Net().to(device)
        input = torch.randn([1, DIM], device=device)

        compiled_model = torch.compile(replicate(deepcopy(model)), fullgraph=True)
        compiled_optim = torch.optim.Adam(compiled_model.parameters())
        model = replicate(model)
        optim = torch.optim.Adam(model.parameters())

        if setup_func:
            setup_func(model, compiled_model)

        # Run multiple iterations so that we could test no_sync
        for i in range(2):
            # Setting a different random seed so that if the allreduces are not
            # executed correctly, the gradients won't be correct compared to the
            # eager DDP.
            torch.manual_seed(123 + self.rank + i)
            input = torch.randn([1, DIM], device=device)

            if no_sync and i % 2 == 0:
                context = replicate.state(model)._ddp.no_sync()
            else:
                context = contextlib.nullcontext()
            with context:
                loss = model(input).sum()
                loss.backward()

            compiled_m = getattr(compiled_model, "_orig_mod", compiled_model)
            if no_sync and i % 2 == 0:
                context = replicate.state(compiled_m)._ddp.no_sync()
            else:
                context = contextlib.nullcontext()
            with context:
                with compiled_autograd.enable(compiler_fn(no_inductor)):
                    compiled_loss = compiled_model(input).sum()
                    compiled_loss.backward()

            if not no_sync or i % 2 == 1:
                for p1, p2 in zip(model.parameters(), compiled_model.parameters()):
                    self.assertEqual(p1.grad, p2.grad)
            compiled_optim.step()
            # Right now we have to use `set_to_none=False`, otherwise
            # the backward will be recompiled every iteration.
            # With `set_to_none=False`, it will only be recompiled once.
            # https://github.com/pytorch/pytorch/issues/118435
            compiled_optim.zero_grad(set_to_none=False)
            optim.step()
            optim.zero_grad()

        self.assertEqual(tuple(model.parameters()), tuple(compiled_model.parameters()))

    def test_compile_cpu(self):
        self._test_compile(use_gpu=False, no_sync=False)

    def test_compile_cpu_no_sync(self):
        self._test_compile(use_gpu=False, no_sync=True)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_rocm
    @skip_if_lt_x_gpu(2)
    def test_compile_gpu(self):
        self._test_compile(use_gpu=True, no_sync=False)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_rocm
    @skip_if_lt_x_gpu(2)
    def test_compile_bf16(self):
        def setup(model, compiled_model) -> None:
            replicate.state(model)._ddp.register_comm_hook(
                None, ddp_default_hooks.bf16_compress_hook
            )
            compiled_m = compiled_model._orig_mod
            replicate.state(compiled_m)._ddp.register_comm_hook(
                None, ddp_default_hooks.bf16_compress_hook
            )

        self._test_compile(
            use_gpu=True, no_sync=False, setup_func=setup, no_inductor=True
        )

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_rocm
    @skip_if_lt_x_gpu(2)
    def test_compile_fp16(self):
        def setup(model, compiled_model) -> None:
            replicate.state(model)._ddp.register_comm_hook(
                None, ddp_default_hooks.fp16_compress_hook
            )
            compiled_m = compiled_model._orig_mod
            replicate.state(compiled_m)._ddp.register_comm_hook(
                None, ddp_default_hooks.fp16_compress_hook
            )

        # TODO: figure out why we need to disable Inductor to avoid test errors.
        self._test_compile(
            use_gpu=True, no_sync=False, setup_func=setup, no_inductor=True
        )

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_rocm
    @skip_if_lt_x_gpu(2)
    def test_compile_backward_only(self):
        self._test_compile(use_gpu=True, no_sync=False, no_compile_forward=True)


if __name__ == "__main__":
    run_tests()
@@ -5,7 +5,7 @@ import re
import sys
import tempfile
from os.path import abspath, dirname
from typing import Any, Dict, Optional, Set, Type, TYPE_CHECKING
from typing import Any, Dict, Optional, Set, Type, TYPE_CHECKING, Union

import torch
@@ -220,12 +220,50 @@ force_unspec_int_unbacked_size_like_on_torchrec_kjt = False
# false_fn produces code with identical guards.
enforce_cond_guards_match = True

# Automatically split model graph into pieces to match DDP bucket sizes
# to allow DDP comm/compute overlap. Disable to allow DDP models to
# run without graph-breaks, but also without comm/compute overlap.
# set TORCH_LOGS env to include any of 'dynamo', 'distributed', or
# 'dist_ddp' for more info about optimize_ddp behavior.
optimize_ddp = True
# Specify how to optimize a compiled DDP module. The flag accepts a boolean
# value or a string. There are 4 modes.
# 1. "ddp_optimizer" (or True): with "ddp_optimizer", Dynamo will automatically
# split the model graph into pieces to match DDP bucket sizes to allow DDP
# comm/compute overlap.
# 2. "python_reducer" (experimental): this optimization requires the usage
# of compiled_autograd. With "python_reducer", DDP will disable the C++ reducer
# and use the Python reducer to allow compiled_autograd to trace the
# communication and allow comm/compute overlap without graph-breaks.
# 3. "python_reducer_without_compiled_forward" (experimental): this mode is
# similar to "python_reducer". One should only use this optimization mode
# when compiled_autograd is used but the DDP module is not compiled.
# 4. "no_optimization" (or False): Dynamo won't split the model graph, nor
# will the Python reducer be used. With this mode, there will be no graph-breaks
# and the original DDP C++ reducer will be used. There will be no comm/compute
# overlap. This mode CANNOT be used with compiled_autograd.
# Note that to avoid breaking the existing usage, mode 1 and mode 4 can be
# specified with a boolean value. True is using ddp_optimizer and False is
# no optimization.
optimize_ddp: Union[bool, str] = True

_ddp_optimization_mode = [
    "ddp_optimizer",
    "python_reducer",  # experimental mode
    "python_reducer_without_compiled_forward",  # experimental mode
    "no_optimization",
]


def _get_optimize_ddp_mode():
    m = sys.modules[__name__]
    if isinstance(m.optimize_ddp, bool):
        if m.optimize_ddp:
            mode = "ddp_optimizer"
        else:
            mode = "no_optimization"
    elif isinstance(m.optimize_ddp, str):
        mode = m.optimize_ddp
    else:
        raise ValueError(f"Invalid type, {type(optimize_ddp)=}")

    assert mode in m._ddp_optimization_mode, f"Invalid mode {mode=}"
    return mode


# If True, delays DDPOptimizer submodule compilation to 1st run of the model,
# so that real tensor strides are used in all submodules
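For a quick reference, a minimal sketch of how the flag values map to modes (mirroring `_get_optimize_ddp_mode` above):

```python
import torch

# Legacy boolean values keep working:
torch._dynamo.config.optimize_ddp = True   # -> "ddp_optimizer" (graph splitting)
torch._dynamo.config.optimize_ddp = False  # -> "no_optimization"

# New experimental string modes added by this PR:
torch._dynamo.config.optimize_ddp = "python_reducer"
torch._dynamo.config.optimize_ddp = "python_reducer_without_compiled_forward"

assert (
    torch._dynamo.config._get_optimize_ddp_mode()
    == "python_reducer_without_compiled_forward"
)
```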
@@ -881,7 +881,7 @@ def catch_errors_wrapper(callback, hooks: Hooks):
        if frame.f_code.co_filename == "<string>" and frame.f_code.co_name == "__new__":
            # namedtuple constructor
            return None
        if config.optimize_ddp:
        if config._get_optimize_ddp_mode() == "ddp_optimizer":
            ddp_module = DistributedDataParallel._get_active_ddp_module()
            if ddp_module:
                with compile_lock:
@@ -188,6 +188,9 @@ if torch.distributed.is_available():
        "torch.distributed.tensor.parallel._data_parallel_utils",
        "torch.distributed.tensor.parallel._utils",
        "torch.distributed.tensor.parallel.style",
        # we have to add replicate to LEGACY_MOD_INLINELIST to ensure
        # the forward_hook won't be ignored.
        "torch.distributed._composable.replicate",
    }


@@ -229,6 +232,7 @@ MOD_INLINELIST = {
if torch.distributed.is_available():
    MOD_INLINELIST.add("torch.distributed")
    MOD_INLINELIST.add("torch.distributed._functional_collectives")
    MOD_INLINELIST.add("torch.distributed._composable.replicate")


@functools.lru_cache(None)
@@ -9,15 +9,18 @@ namespace utils {
// Turns lambda into a torch::autograd::FunctionPostHook.
class LambdaPostHook : public torch::autograd::FunctionPostHook {
  using variable_list = std::vector<torch::autograd::Variable>;
  using fn_type =
      std::function<variable_list(const variable_list&, const variable_list&)>;
  using compiled_fn_type = std::function<void(CompiledNodeArgs&)>;

 public:
  // The lambda function takes as arguments the outputs and inputs of the
  // autograd function and can modify the outputs of the autograd function by
  // returning a new output if needed.
  /* implicit */ LambdaPostHook(
      std::function<variable_list(const variable_list&, const variable_list&)>
          fn)
      : fn_(std::move(fn)) {}
  /* implicit */ LambdaPostHook(fn_type fn) : fn_(std::move(fn)) {}

  LambdaPostHook(fn_type fn, compiled_fn_type compiled_fn)
      : fn_(std::move(fn)), compiled_fn_(std::move(compiled_fn)) {}

  variable_list operator()(
      const variable_list& outputs,
@@ -25,8 +28,11 @@ class LambdaPostHook : public torch::autograd::FunctionPostHook {
    return fn_(outputs, inputs);
  }

  void compiled_args(CompiledNodeArgs& args) override {}

 protected:
  std::function<variable_list(const variable_list&, const variable_list&)> fn_;
  compiled_fn_type compiled_fn_;
};

} // namespace utils
@@ -193,6 +193,9 @@ Reducer::Reducer(
#endif
              this->autograd_hook(variable_index);
              return outputs;
            },
            [=](torch::autograd::CompiledNodeArgs& args) {
              // Make post_hook a no-op if compiled_autograd is enabled.
            })),
        grad_accumulator);
@@ -1,9 +1,16 @@
from typing import Any, Callable
from typing import Any, Callable, cast, Tuple

import torch
import torch.distributed as dist

__all__ = ["allreduce_hook", "fp16_compress_hook", "bf16_compress_hook", "fp16_compress_wrapper", "bf16_compress_wrapper"]
__all__ = [
    "allreduce_hook",
    "fp16_compress_hook",
    "bf16_compress_hook",
    "fp16_compress_wrapper",
    "bf16_compress_wrapper",
]


def _allreduce_fut(
    process_group: dist.ProcessGroup, tensor: torch.Tensor
@@ -44,7 +51,8 @@ def allreduce_hook(


def fp16_compress_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket,
) -> torch.futures.Future[torch.Tensor]:
    """
    Compress by casting ``GradBucket`` to ``torch.float16`` divided by process group size.
@@ -62,24 +70,37 @@ def fp16_compress_hook(
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = group_to_use.size()

    compressed_tensor = bucket.buffer().to(torch.float16).div_(world_size)

    fut = dist.all_reduce(
        compressed_tensor, group=group_to_use, async_op=True
    ).get_future()
    buffer = (
        cast(Tuple[torch.Tensor, ...], bucket)[0]
        if isinstance(bucket, tuple)
        else bucket.buffer()
    )
    compressed_tensor = buffer.to(torch.float16).div_(world_size)

    def decompress(fut):
        decompressed_tensor = bucket.buffer()
        decompressed_tensor = buffer
        # Decompress in place to reduce the peak memory.
        # See: https://github.com/pytorch/pytorch/issues/45968
        decompressed_tensor.copy_(fut.value()[0])
        value = fut if isinstance(fut, torch.Tensor) else fut.value()[0]
        decompressed_tensor.copy_(value)
        return decompressed_tensor

    return fut.then(decompress)
    if torch._utils.is_compiling():
        grad = dist._functional_collectives.all_reduce(
            compressed_tensor, "sum", group_to_use
        )
        return decompress(grad)
    else:
        fut = dist.all_reduce(
            compressed_tensor, group=group_to_use, async_op=True
        ).get_future()
        return fut.then(decompress)


# TODO: create an internal helper function and extract the duplicate code in FP16_compress and BF16_compress.
def bf16_compress_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket,
) -> torch.futures.Future[torch.Tensor]:
    """
    Warning: This API is experimental, and it requires NCCL version later than 2.9.6.
@@ -98,20 +119,31 @@ def bf16_compress_hook(
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = group_to_use.size()

    compressed_tensor = bucket.buffer().to(torch.bfloat16).div_(world_size)

    fut = dist.all_reduce(
        compressed_tensor, group=group_to_use, async_op=True
    ).get_future()
    buffer = (
        cast(Tuple[torch.Tensor, ...], bucket)[0]
        if isinstance(bucket, tuple)
        else bucket.buffer()
    )
    compressed_tensor = buffer.to(torch.bfloat16).div_(world_size)

    def decompress(fut):
        decompressed_tensor = bucket.buffer()
        decompressed_tensor = buffer
        # Decompress in place to reduce the peak memory.
        # See: https://github.com/pytorch/pytorch/issues/45968
        decompressed_tensor.copy_(fut.value()[0])
        value = fut if isinstance(fut, torch.Tensor) else fut.value()[0]
        decompressed_tensor.copy_(value)
        return decompressed_tensor

    return fut.then(decompress)
    if torch._utils.is_compiling():
        grad = dist._functional_collectives.all_reduce(
            compressed_tensor, "sum", group_to_use
        )
        return decompress(grad)
    else:
        fut = dist.all_reduce(
            compressed_tensor, group=group_to_use, async_op=True
        ).get_future()
        return fut.then(decompress)


def fp16_compress_wrapper(
@@ -151,6 +183,7 @@ def fp16_compress_wrapper(

    return fp16_compress_wrapper_hook


def bf16_compress_wrapper(
    hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
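For illustration, a minimal sketch of how the Python reducer path described in this PR invokes such a hook: the "bucket" argument is a plain `(grad, param)` tuple rather than a `dist.GradBucket`, and the hook writes the decompressed, reduced result back into `param.grad` in place. `param` is a placeholder, and a process group is assumed to be initialized:

```python
import torch
import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook

# Placeholder parameter whose gradient has already been accumulated.
param = torch.nn.Parameter(torch.randn(4))
param.grad = torch.randn(4)

# The hook compresses to fp16, all-reduces (dividing by world size), and
# copies the decompressed result back into param.grad for the tuple form.
fp16_compress_hook(dist.group.WORLD, (param.grad, param))
```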
@@ -11,14 +11,14 @@ from collections import defaultdict, deque
from contextlib import contextmanager
from dataclasses import dataclass, fields, is_dataclass
from enum import auto, Enum
from typing import Any, Callable, List, Optional, Type
from typing import Any, Callable, List, Optional, Tuple, Type

import torch
import torch.distributed as dist
from torch.autograd import Function, Variable
from torch.distributed.algorithms.join import Join, Joinable, JoinHook

from torch.utils._pytree import tree_flatten, tree_unflatten
from torch.utils.hooks import RemovableHandle

RPC_AVAILABLE = False
if dist.is_available():
@@ -815,6 +815,8 @@ class DistributedDataParallel(Module, Joinable):
            param_to_name_mapping,
            static_graph,
        )
        self._comm_hooks: List[Tuple[Callable, object]] = []

        if self.mixed_precision is not None:
            _setup_mixed_precision_params(self.mixed_precision, self.module)
            _cast_buffers(self.mixed_precision, self.module)
@@ -864,6 +866,50 @@ class DistributedDataParallel(Module, Joinable):

        self._lazy_init_ran = False

        # Register the AccumulateGrad post hooks if optimize_ddp is
        # True. The hooks will be deregistered if compiled_autograd is not
        # enabled.
        self._accum_grad_hooks: List[RemovableHandle] = []
        optimize_ddp = torch._dynamo.config._get_optimize_ddp_mode()
        self._use_python_reducer = optimize_ddp in (
            "python_reducer",
            "python_reducer_without_compiled_forward",
        )
        self._force_to_disable_cpp_reducer = (
            optimize_ddp == "python_reducer_without_compiled_forward"
        )
        if self._use_python_reducer:
            self._register_accum_grad_hook()

    def _register_accum_grad_hook(self):
        import torch.distributed._functional_collectives as fcol

        def compiled_accum_grad_hook(
            param,
            *,
            param_index: int,
        ):
            if not self.require_backward_grad_sync:
                return

            if self._comm_hooks:
                for hook, state in self._comm_hooks:
                    hook(state, (param.grad, param))
            else:
                gradient = param.grad / self.process_group.size()
                gradient = fcol.all_reduce(gradient, "sum", self.process_group)
                param.grad.copy_(gradient)

        for index, param in enumerate(self._module_parameters):
            self._accum_grad_hooks.append(
                param.register_post_accumulate_grad_hook(
                    functools.partial(
                        compiled_accum_grad_hook,
                        param_index=index,
                    )
                )
            )

    def _delayed_all_reduce_hook(self, grad):
        world_size = dist.get_world_size(self.process_group)

@@ -1355,8 +1401,11 @@ class DistributedDataParallel(Module, Joinable):
        DistributedDataParallel._active_ddp_module = None

    def _run_ddp_forward(self, *inputs, **kwargs):
        with self._inside_ddp_forward():
        if self._use_python_reducer:
            return self.module(*inputs, **kwargs)  # type: ignore[index]
        else:
            with self._inside_ddp_forward():
                return self.module(*inputs, **kwargs)  # type: ignore[index]

    def _clear_grad_buffer(self):
        # Making param.grad point to the grad buffers before backward is based on the
@@ -1385,9 +1434,24 @@ class DistributedDataParallel(Module, Joinable):
            self._setup_in_backward_optimizers()
        self._lazy_init_ran = True

    def _should_disable_cpp_reducer(self) -> bool:
        return self._use_python_reducer and (
            torch._utils.is_compiling() or self._force_to_disable_cpp_reducer
        )

    def _pre_forward(self, *inputs, **kwargs):
        if not self._lazy_init_ran:
        if self._should_disable_cpp_reducer():
            return inputs, kwargs

        # Disable the python reducer if compiled_autograd is not enabled.
        if self._accum_grad_hooks:
            for index, h in enumerate(self._accum_grad_hooks):
                h.remove()
            self._accum_grad_hooks.clear()

        if not self._lazy_init_ran and not torch._utils.is_compiling():
            self._lazy_init()

        if self._delay_all_reduce_all_params:
            return inputs, kwargs

@@ -1451,6 +1515,9 @@ class DistributedDataParallel(Module, Joinable):
        return inputs, kwargs

    def _post_forward(self, output):
        if self._should_disable_cpp_reducer():
            return output

        if self._delay_all_reduce_all_params:
            self._clear_grad_buffer()
            return output
@@ -1883,8 +1950,16 @@ class DistributedDataParallel(Module, Joinable):
            >>> ddp.register_comm_hook(state=None, hook=encode_and_decode)
        """
        self._check_comm_hook(hook)
        if hook.__name__ in ["bf16_compress_hook", "fp16_compress_hook"]:
            # If we pass None, then the hook will try to get the world size
            # by calling `dist.group.WORLD.size()`, which causes compilation
            # errors. So we pre-decode the process group and pass it to the
            # hook.
            if state is None:
                state = dist.group.WORLD
        assert self.logger is not None
        self.logger._set_comm_hook_name(hook.__qualname__)
        self._comm_hooks.append((hook, state))
        dist._register_comm_hook(self.reducer, state, hook)

    def _register_builtin_comm_hook(self, comm_hook_type):