[DDP] Use compiled_autograd to trace DDP backward allreduce (#110662)

**Summary**
The reducer of `DistributedDataParallel` is implemented in C++, which makes it hard to trace the allreduces it launches. This PR modifies `DistributedDataParallel` to launch one allreduce per gradient when `compiled_autograd` is enabled. This change allows `compiled_autograd` to trace the allreduces so that they can later be optimized (fused) by Inductor.

**Key Logic**
1. If `ddp_python_hook` is True, we assume `compiled_autograd` is used, and `DistributedDataParallel` registers `compiled_accum_grad_hook` for all parameters.
2. In the first `forward()` call, if `DistributedDataParallel` is not compiled, all `compiled_accum_grad_hook`s are deregistered. If `DistributedDataParallel` is compiled, all `compiled_accum_grad_hook`s will be compiled by `compiled_autograd`.
3. `compiled_accum_grad_hook` launches an allreduce to reduce the gradient of the parameter (see the sketch below).
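
A simplified sketch of this per-gradient hook, closely following `_register_accum_grad_hook` in the diff below (the `register_per_param_allreduce` wrapper is a hypothetical helper added here for illustration, not part of this PR):

```python
import functools

import torch
import torch.distributed._functional_collectives as fcol


def compiled_accum_grad_hook(param: torch.nn.Parameter, *, process_group) -> None:
    # Average the gradient with a traceable functional collective and write
    # the result back into param.grad in place.
    gradient = param.grad / process_group.size()
    gradient = fcol.all_reduce(gradient, "sum", process_group)
    param.grad.copy_(gradient)


def register_per_param_allreduce(module: torch.nn.Module, process_group) -> None:
    # One post-accumulate-grad hook per parameter, so compiled_autograd sees
    # one allreduce per gradient in the traced backward graph.
    for param in module.parameters():
        param.register_post_accumulate_grad_hook(
            functools.partial(compiled_accum_grad_hook, process_group=process_group)
        )
```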

**Bucketing**
The compiled backward is currently slow because the allreduces are not bucketed. We rely on Inductor to bucket the allreduces.

The bucketing is done in a separate PR.
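
For reference, the compiled backward mentioned above is driven through `compiled_autograd.enable`, as in the test below. A minimal single-process sketch (the backend choice and the toy model/input here are illustrative only):

```python
import torch
from torch import nn
from torch._dynamo import compiled_autograd


def compiler_fn(gm: torch.fx.GraphModule):
    # Compile the captured backward graph; Inductor can then fuse (bucket)
    # the per-gradient allreduces that appear in it.
    return torch.compile(gm, fullgraph=True, backend="inductor")


model = nn.Linear(8, 8)
inp = torch.randn(4, 8)

# Trace the backward (including any per-gradient collectives registered on
# the parameters) with compiled_autograd.
with compiled_autograd.enable(compiler_fn):
    loss = model(inp).sum()
    loss.backward()
```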

Differential Revision: [D49428482](https://our.internmc.facebook.com/intern/diff/D49428482/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110662
Approved by: https://github.com/wconstab
Author: Chien-Chin Huang
Date: 2024-02-07 11:11:36 -08:00
Committed by: PyTorch MergeBot
Parent: 113506d2d4
Commit: 1d2382f141
8 changed files with 406 additions and 36 deletions


@@ -0,0 +1,211 @@
# Owner(s): ["oncall: distributed"]
import contextlib
import os
import unittest
from copy import deepcopy
from typing import Callable, Optional
import torch
import torch.distributed as dist
from torch import _inductor as inductor, nn
from torch._dynamo import compiled_autograd
from torch.distributed._composable.replicate import replicate
from torch.distributed.algorithms.ddp_comm_hooks import (
default_hooks as ddp_default_hooks,
)
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
skip_if_lt_x_gpu,
skip_if_rocm,
)
from torch.testing._internal.common_utils import run_tests
from torch.utils._triton import has_triton
from torch.utils.checkpoint import checkpoint
DIM = 2000
class Net(nn.Module):
def __init__(self, checkpoint=False):
super().__init__()
self.fc1 = nn.Linear(DIM, DIM)
self.fc2 = nn.Linear(DIM, DIM)
self.fc3 = nn.Linear(DIM, DIM)
self.fc4 = nn.Linear(DIM, DIM)
self.use_checkpoint = checkpoint
def forward(self, x):
if self.use_checkpoint:
_fc1 = checkpoint(self.fc1, x, use_reentrant=False)
else:
_fc1 = self.fc1(x)
return self.fc4(self.fc3(self.fc2(_fc1)))
def compiler_fn(no_inductor):
def _compiler_fn(gm):
def inner_compiler(gm_, example_inputs_):
if no_inductor:
return gm_
else:
return inductor.compile(gm_, example_inputs_)
gm = torch.compile(gm, fullgraph=True, backend=inner_compiler)
return gm
return _compiler_fn
class ReplicateTest(MultiProcessTestCase):
@property
def world_size(self) -> int:
return 2
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
def tearDown(self):
super().tearDown()
try:
os.remove(self.file_name)
except OSError:
pass
def _test_compile(
self,
*,
use_gpu: bool,
no_sync: bool,
setup_func: Optional[Callable] = None,
no_inductor: bool = False,
no_compile_forward: bool = False,
):
backend = "nccl" if use_gpu else "gloo"
dist.init_process_group(
backend=backend,
rank=self.rank,
world_size=self.world_size,
store=dist.FileStore(self.file_name, self.world_size),
)
if use_gpu:
torch.cuda.set_device(f"cuda:{self.rank}")
device = torch.device("cuda")
else:
device = torch.device("cpu")
torch._dynamo.config.optimize_ddp = (
"python_reducer_without_compiled_forward"
if no_compile_forward
else "python_reducer"
)
torch.manual_seed(123)
model = Net().to(device)
input = torch.randn([1, DIM], device=device)
compiled_model = torch.compile(replicate(deepcopy(model)), fullgraph=True)
compiled_optim = torch.optim.Adam(compiled_model.parameters())
model = replicate(model)
optim = torch.optim.Adam(model.parameters())
if setup_func:
setup_func(model, compiled_model)
# Run multiple iterations so that we can test no_sync
for i in range(2):
# Set a different random seed so that if the allreduces are not
# executed correctly, the gradients won't match those of the
# eager DDP model.
torch.manual_seed(123 + self.rank + i)
input = torch.randn([1, DIM], device=device)
if no_sync and i % 2 == 0:
context = replicate.state(model)._ddp.no_sync()
else:
context = contextlib.nullcontext()
with context:
loss = model(input).sum()
loss.backward()
compiled_m = getattr(compiled_model, "_orig_mod", compiled_model)
if no_sync and i % 2 == 0:
context = replicate.state(compiled_m)._ddp.no_sync()
else:
context = contextlib.nullcontext()
with context:
with compiled_autograd.enable(compiler_fn(no_inductor)):
compiled_loss = compiled_model(input).sum()
compiled_loss.backward()
if not no_sync or i % 2 == 1:
for p1, p2 in zip(model.parameters(), compiled_model.parameters()):
self.assertEqual(p1.grad, p2.grad)
compiled_optim.step()
# Right now we have to use `set_to_none=False`, otherwise
# the backward will be recompiled every iteration.
# With `set_to_none=False`, it will only be recompiled once.
# https://github.com/pytorch/pytorch/issues/118435
compiled_optim.zero_grad(set_to_none=False)
optim.step()
optim.zero_grad()
self.assertEqual(tuple(model.parameters()), tuple(compiled_model.parameters()))
def test_compile_cpu(self):
self._test_compile(use_gpu=False, no_sync=False)
def test_compile_cpu_no_sync(self):
self._test_compile(use_gpu=False, no_sync=True)
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_rocm
@skip_if_lt_x_gpu(2)
def test_compile_gpu(self):
self._test_compile(use_gpu=True, no_sync=False)
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_rocm
@skip_if_lt_x_gpu(2)
def test_compile_bf16(self):
def setup(model, compiled_model) -> None:
replicate.state(model)._ddp.register_comm_hook(
None, ddp_default_hooks.bf16_compress_hook
)
compiled_m = compiled_model._orig_mod
replicate.state(compiled_m)._ddp.register_comm_hook(
None, ddp_default_hooks.bf16_compress_hook
)
self._test_compile(
use_gpu=True, no_sync=False, setup_func=setup, no_inductor=True
)
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_rocm
@skip_if_lt_x_gpu(2)
def test_compile_fp16(self):
def setup(model, compiled_model) -> None:
replicate.state(model)._ddp.register_comm_hook(
None, ddp_default_hooks.fp16_compress_hook
)
compiled_m = compiled_model._orig_mod
replicate.state(compiled_m)._ddp.register_comm_hook(
None, ddp_default_hooks.fp16_compress_hook
)
# TODO: figure out why we need to disable Inductor to avoid test errors.
self._test_compile(
use_gpu=True, no_sync=False, setup_func=setup, no_inductor=True
)
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_rocm
@skip_if_lt_x_gpu(2)
def test_compile_backward_only(self):
self._test_compile(use_gpu=True, no_sync=False, no_compile_forward=True)
if __name__ == "__main__":
run_tests()


@@ -5,7 +5,7 @@ import re
import sys
import tempfile
from os.path import abspath, dirname
from typing import Any, Dict, Optional, Set, Type, TYPE_CHECKING
from typing import Any, Dict, Optional, Set, Type, TYPE_CHECKING, Union
import torch
@@ -220,12 +220,50 @@ force_unspec_int_unbacked_size_like_on_torchrec_kjt = False
# false_fn produces code with identical guards.
enforce_cond_guards_match = True
# Automatically split model graph into pieces to match DDP bucket sizes
# to allow DDP comm/compute overlap. Disable to allow DDP models to
# run without graph-breaks, but also without comm/compute overlap.
# set TORCH_LOGS env to include any of 'dynamo', 'distributed', or
# 'dist_ddp' for more info about optimize_ddp behavior.
optimize_ddp = True
# Specify how to optimize a compiled DDP module. The flag accepts a boolean
# value or a string. There are 4 modes.
# 1. "ddp_optimizer" (or True): with "ddp_optimizer", Dynamo will automatically
# split model graph into pieces to match DDP bucket sizes to allow DDP
# comm/compute overlap.
# 2. "python_reducer" (experimental): this optimization requires the usage
# of compiled_autograd. With "python_reducer", DDP will disable the C++ reducer
# and use the Python reducer to allow compiled_autograd to trace the
# communication and allow comm/compute overlap without graph-breaks.
# 3. "python_reducer_without_compiled_forward" (experimental): this mode is
# similar to "python_reducer". One should only use this optimization mode
# when compiled_autograd is used but the DDP module is not compiled.
# 4. "no_optimization" (or False): Dynamo won't split the model graph, nor
# will Python reducer be used. With this mode, there will be no graph-breaks
# and the original DDP C++ reducer will be used. There will be no comm/compute
# overlap. This mode CANNOT be used with compiled_autograd.
# Note that, to avoid breaking existing usage, mode 1 and mode 4 can also be
# specified with a boolean value: True means "ddp_optimizer" and False means
# "no_optimization".
optimize_ddp: Union[bool, str] = True
_ddp_optimization_mode = [
"ddp_optimizer",
"python_reducer", # experimental mode
"python_reducer_without_compiled_forward", # experimental mode
"no_optimization",
]
def _get_optimize_ddp_mode():
m = sys.modules[__name__]
if isinstance(m.optimize_ddp, bool):
if m.optimize_ddp:
mode = "ddp_optimizer"
else:
mode = "no_optimization"
elif isinstance(m.optimize_ddp, str):
mode = m.optimize_ddp
else:
raise ValueError(f"Invalid type, {type(optimize_ddp)=}")
assert mode in m._ddp_optimization_mode, f"Invalid mode {mode=}"
return mode
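
A short usage sketch for this config (outside the diff, for illustration), mirroring what the test file above does; the string must be one of `_ddp_optimization_mode`:

```python
import torch

# Opt into the experimental Python reducer so compiled_autograd can trace
# one allreduce per gradient instead of the C++ reducer's bucketed ones.
torch._dynamo.config.optimize_ddp = "python_reducer"

# Legacy boolean values still work: True maps to "ddp_optimizer" and
# False maps to "no_optimization".
torch._dynamo.config.optimize_ddp = True
```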
# If True, delays DDPOptimizer submodule compilation to 1st run of the model,
# so that real tensor strides are used in all submodules


@@ -881,7 +881,7 @@ def catch_errors_wrapper(callback, hooks: Hooks):
if frame.f_code.co_filename == "<string>" and frame.f_code.co_name == "__new__":
# namedtuple constructor
return None
if config.optimize_ddp:
if config._get_optimize_ddp_mode() == "ddp_optimizer":
ddp_module = DistributedDataParallel._get_active_ddp_module()
if ddp_module:
with compile_lock:


@@ -188,6 +188,9 @@ if torch.distributed.is_available():
"torch.distributed.tensor.parallel._data_parallel_utils",
"torch.distributed.tensor.parallel._utils",
"torch.distributed.tensor.parallel.style",
# we have to add replicate to LEGACY_MOD_INLINELIST to ensure
# the forward_hook won't be ignored.
"torch.distributed._composable.replicate",
}
@@ -229,6 +232,7 @@ MOD_INLINELIST = {
if torch.distributed.is_available():
MOD_INLINELIST.add("torch.distributed")
MOD_INLINELIST.add("torch.distributed._functional_collectives")
MOD_INLINELIST.add("torch.distributed._composable.replicate")
@functools.lru_cache(None)


@@ -9,15 +9,18 @@ namespace utils {
// Turns lambda into a torch::autograd::FunctionPostHook.
class LambdaPostHook : public torch::autograd::FunctionPostHook {
using variable_list = std::vector<torch::autograd::Variable>;
using fn_type =
std::function<variable_list(const variable_list&, const variable_list&)>;
using compiled_fn_type = std::function<void(CompiledNodeArgs&)>;
public:
// The lambda function takes as arguments the outputs and inputs of the
// autograd function and can modify the outputs of the autograd function by
// returning a new output if needed.
/* implicit */ LambdaPostHook(
std::function<variable_list(const variable_list&, const variable_list&)>
fn)
: fn_(std::move(fn)) {}
/* implicit */ LambdaPostHook(fn_type fn) : fn_(std::move(fn)) {}
LambdaPostHook(fn_type fn, compiled_fn_type compiled_fn)
: fn_(std::move(fn)), compiled_fn_(std::move(compiled_fn)) {}
variable_list operator()(
const variable_list& outputs,
@@ -25,8 +28,11 @@ class LambdaPostHook : public torch::autograd::FunctionPostHook {
return fn_(outputs, inputs);
}
void compiled_args(CompiledNodeArgs& args) override {}
protected:
std::function<variable_list(const variable_list&, const variable_list&)> fn_;
compiled_fn_type compiled_fn_;
};
} // namespace utils


@@ -193,6 +193,9 @@ Reducer::Reducer(
#endif
this->autograd_hook(variable_index);
return outputs;
},
[=](torch::autograd::CompiledNodeArgs& args) {
// Make post_hook a no-op if compiled_autograd is enabled.
})),
grad_accumulator);


@@ -1,9 +1,16 @@
from typing import Any, Callable
from typing import Any, Callable, cast, Tuple
import torch
import torch.distributed as dist
__all__ = ["allreduce_hook", "fp16_compress_hook", "bf16_compress_hook", "fp16_compress_wrapper", "bf16_compress_wrapper"]
__all__ = [
"allreduce_hook",
"fp16_compress_hook",
"bf16_compress_hook",
"fp16_compress_wrapper",
"bf16_compress_wrapper",
]
def _allreduce_fut(
process_group: dist.ProcessGroup, tensor: torch.Tensor
@@ -44,7 +51,8 @@ def allreduce_hook(
def fp16_compress_hook(
process_group: dist.ProcessGroup, bucket: dist.GradBucket
process_group: dist.ProcessGroup,
bucket: dist.GradBucket,
) -> torch.futures.Future[torch.Tensor]:
"""
Compress by casting ``GradBucket`` to ``torch.float16`` divided by process group size.
@@ -62,24 +70,37 @@ def fp16_compress_hook(
group_to_use = process_group if process_group is not None else dist.group.WORLD
world_size = group_to_use.size()
compressed_tensor = bucket.buffer().to(torch.float16).div_(world_size)
fut = dist.all_reduce(
compressed_tensor, group=group_to_use, async_op=True
).get_future()
buffer = (
cast(Tuple[torch.Tensor, ...], bucket)[0]
if isinstance(bucket, tuple)
else bucket.buffer()
)
compressed_tensor = buffer.to(torch.float16).div_(world_size)
def decompress(fut):
decompressed_tensor = bucket.buffer()
decompressed_tensor = buffer
# Decompress in place to reduce the peak memory.
# See: https://github.com/pytorch/pytorch/issues/45968
decompressed_tensor.copy_(fut.value()[0])
value = fut if isinstance(fut, torch.Tensor) else fut.value()[0]
decompressed_tensor.copy_(value)
return decompressed_tensor
return fut.then(decompress)
if torch._utils.is_compiling():
grad = dist._functional_collectives.all_reduce(
compressed_tensor, "sum", group_to_use
)
return decompress(grad)
else:
fut = dist.all_reduce(
compressed_tensor, group=group_to_use, async_op=True
).get_future()
return fut.then(decompress)
# TODO: create an internal helper function and extract the duplicate code in FP16_compress and BF16_compress.
def bf16_compress_hook(
process_group: dist.ProcessGroup, bucket: dist.GradBucket
process_group: dist.ProcessGroup,
bucket: dist.GradBucket,
) -> torch.futures.Future[torch.Tensor]:
"""
Warning: This API is experimental, and it requires NCCL version later than 2.9.6.
@@ -98,20 +119,31 @@ def bf16_compress_hook(
group_to_use = process_group if process_group is not None else dist.group.WORLD
world_size = group_to_use.size()
compressed_tensor = bucket.buffer().to(torch.bfloat16).div_(world_size)
fut = dist.all_reduce(
compressed_tensor, group=group_to_use, async_op=True
).get_future()
buffer = (
cast(Tuple[torch.Tensor, ...], bucket)[0]
if isinstance(bucket, tuple)
else bucket.buffer()
)
compressed_tensor = buffer.to(torch.bfloat16).div_(world_size)
def decompress(fut):
decompressed_tensor = bucket.buffer()
decompressed_tensor = buffer
# Decompress in place to reduce the peak memory.
# See: https://github.com/pytorch/pytorch/issues/45968
decompressed_tensor.copy_(fut.value()[0])
value = fut if isinstance(fut, torch.Tensor) else fut.value()[0]
decompressed_tensor.copy_(value)
return decompressed_tensor
return fut.then(decompress)
if torch._utils.is_compiling():
grad = dist._functional_collectives.all_reduce(
compressed_tensor, "sum", group_to_use
)
return decompress(grad)
else:
fut = dist.all_reduce(
compressed_tensor, group=group_to_use, async_op=True
).get_future()
return fut.then(decompress)
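
A brief usage sketch for these hooks (outside the diff, for illustration), mirroring the test's `setup` functions; the single-process gloo setup here is only to make the snippet self-contained:

```python
import os

import torch
import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as ddp_default_hooks
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process gloo setup purely for illustration.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

torch._dynamo.config.optimize_ddp = "python_reducer"  # experimental Python reducer
ddp_model = DDP(torch.nn.Linear(8, 8))

# With state=None, register_comm_hook now resolves dist.group.WORLD up front
# for the fp16/bf16 hooks (see the distributed.py change below) and also
# records the hook so the per-gradient Python reducer can invoke it.
ddp_model.register_comm_hook(None, ddp_default_hooks.fp16_compress_hook)
```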
def fp16_compress_wrapper(
@@ -151,6 +183,7 @@ def fp16_compress_wrapper(
return fp16_compress_wrapper_hook
def bf16_compress_wrapper(
hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]
) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:


@@ -11,14 +11,14 @@ from collections import defaultdict, deque
from contextlib import contextmanager
from dataclasses import dataclass, fields, is_dataclass
from enum import auto, Enum
from typing import Any, Callable, List, Optional, Type
from typing import Any, Callable, List, Optional, Tuple, Type
import torch
import torch.distributed as dist
from torch.autograd import Function, Variable
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from torch.utils._pytree import tree_flatten, tree_unflatten
from torch.utils.hooks import RemovableHandle
RPC_AVAILABLE = False
if dist.is_available():
@@ -815,6 +815,8 @@ class DistributedDataParallel(Module, Joinable):
param_to_name_mapping,
static_graph,
)
self._comm_hooks: List[Tuple[Callable, object]] = []
if self.mixed_precision is not None:
_setup_mixed_precision_params(self.mixed_precision, self.module)
_cast_buffers(self.mixed_precision, self.module)
@@ -864,6 +866,50 @@
self._lazy_init_ran = False
# Register the AccumulateGrad post hooks if one of the Python-reducer
# optimize_ddp modes is selected. The hooks will be deregistered if
# compiled_autograd is not enabled.
self._accum_grad_hooks: List[RemovableHandle] = []
optimize_ddp = torch._dynamo.config._get_optimize_ddp_mode()
self._use_python_reducer = optimize_ddp in (
"python_reducer",
"python_reducer_without_compiled_forward",
)
self._force_to_disable_cpp_reducer = (
optimize_ddp == "python_reducer_without_compiled_forward"
)
if self._use_python_reducer:
self._register_accum_grad_hook()
def _register_accum_grad_hook(self):
import torch.distributed._functional_collectives as fcol
def compiled_accum_grad_hook(
param,
*,
param_index: int,
):
if not self.require_backward_grad_sync:
return
if self._comm_hooks:
for hook, state in self._comm_hooks:
hook(state, (param.grad, param))
else:
gradient = param.grad / self.process_group.size()
gradient = fcol.all_reduce(gradient, "sum", self.process_group)
param.grad.copy_(gradient)
for index, param in enumerate(self._module_parameters):
self._accum_grad_hooks.append(
param.register_post_accumulate_grad_hook(
functools.partial(
compiled_accum_grad_hook,
param_index=index,
)
)
)
def _delayed_all_reduce_hook(self, grad):
world_size = dist.get_world_size(self.process_group)
@@ -1355,8 +1401,11 @@
DistributedDataParallel._active_ddp_module = None
def _run_ddp_forward(self, *inputs, **kwargs):
with self._inside_ddp_forward():
if self._use_python_reducer:
return self.module(*inputs, **kwargs) # type: ignore[index]
else:
with self._inside_ddp_forward():
return self.module(*inputs, **kwargs) # type: ignore[index]
def _clear_grad_buffer(self):
# Making param.grad points to the grad buffers before backward is based on the
@@ -1385,9 +1434,24 @@
self._setup_in_backward_optimizers()
self._lazy_init_ran = True
def _should_disable_cpp_reducer(self) -> bool:
return self._use_python_reducer and (
torch._utils.is_compiling() or self._force_to_disable_cpp_reducer
)
def _pre_forward(self, *inputs, **kwargs):
if not self._lazy_init_ran:
if self._should_disable_cpp_reducer():
return inputs, kwargs
# Disable the python reducer if compiled_autograd is not enabled.
if self._accum_grad_hooks:
for index, h in enumerate(self._accum_grad_hooks):
h.remove()
self._accum_grad_hooks.clear()
if not self._lazy_init_ran and not torch._utils.is_compiling():
self._lazy_init()
if self._delay_all_reduce_all_params:
return inputs, kwargs
@@ -1451,6 +1515,9 @@
return inputs, kwargs
def _post_forward(self, output):
if self._should_disable_cpp_reducer():
return output
if self._delay_all_reduce_all_params:
self._clear_grad_buffer()
return output
@@ -1883,8 +1950,16 @@
>>> ddp.register_comm_hook(state=None, hook=encode_and_decode)
"""
self._check_comm_hook(hook)
if hook.__name__ in ["bf16_compress_hook", "fp16_compress_hook"]:
# If we pass None, then the hook will try to get the world size
# by calling `dist.group.WORLD.size()`, which causes compilation
# errors. So we resolve the process group up front and pass it to the
# hook.
if state is None:
state = dist.group.WORLD
assert self.logger is not None
self.logger._set_comm_hook_name(hook.__qualname__)
self._comm_hooks.append((hook, state))
dist._register_comm_hook(self.reducer, state, hook)
def _register_builtin_comm_hook(self, comm_hook_type):