diff --git a/test/distributed/_composable/test_replicate_with_compiler.py b/test/distributed/_composable/test_replicate_with_compiler.py
new file mode 100644
index 000000000000..994b3a80761f
--- /dev/null
+++ b/test/distributed/_composable/test_replicate_with_compiler.py
@@ -0,0 +1,211 @@
+# Owner(s): ["oncall: distributed"]
+
+import contextlib
+import os
+import unittest
+from copy import deepcopy
+from typing import Callable, Optional
+
+import torch
+import torch.distributed as dist
+from torch import _inductor as inductor, nn
+from torch._dynamo import compiled_autograd
+from torch.distributed._composable.replicate import replicate
+from torch.distributed.algorithms.ddp_comm_hooks import (
+    default_hooks as ddp_default_hooks,
+)
+from torch.testing._internal.common_distributed import (
+    MultiProcessTestCase,
+    skip_if_lt_x_gpu,
+    skip_if_rocm,
+)
+from torch.testing._internal.common_utils import run_tests
+from torch.utils._triton import has_triton
+from torch.utils.checkpoint import checkpoint
+
+
+DIM = 2000
+
+
+class Net(nn.Module):
+    def __init__(self, checkpoint=False):
+        super().__init__()
+        self.fc1 = nn.Linear(DIM, DIM)
+        self.fc2 = nn.Linear(DIM, DIM)
+        self.fc3 = nn.Linear(DIM, DIM)
+        self.fc4 = nn.Linear(DIM, DIM)
+        self.use_checkpoint = checkpoint
+
+    def forward(self, x):
+        if self.use_checkpoint:
+            _fc1 = checkpoint(self.fc1, x, use_reentrant=False)
+        else:
+            _fc1 = self.fc1(x)
+        return self.fc4(self.fc3(self.fc2(_fc1)))
+
+
+def compiler_fn(no_inductor):
+    def _compiler_fn(gm):
+        def inner_compiler(gm_, example_inputs_):
+            if no_inductor:
+                return gm_
+            else:
+                return inductor.compile(gm_, example_inputs_)
+
+        gm = torch.compile(gm, fullgraph=True, backend=inner_compiler)
+        return gm
+
+    return _compiler_fn
+
+
+class ReplicateTest(MultiProcessTestCase):
+    @property
+    def world_size(self) -> int:
+        return 2
+
+    def setUp(self) -> None:
+        super().setUp()
+        self._spawn_processes()
+
+    def tearDown(self):
+        super().tearDown()
+        try:
+            os.remove(self.file_name)
+        except OSError:
+            pass
+
+    def _test_compile(
+        self,
+        *,
+        use_gpu: bool,
+        no_sync: bool,
+        setup_func: Optional[Callable] = None,
+        no_inductor: bool = False,
+        no_compile_forward: bool = False,
+    ):
+        backend = "nccl" if use_gpu else "gloo"
+        dist.init_process_group(
+            backend=backend,
+            rank=self.rank,
+            world_size=self.world_size,
+            store=dist.FileStore(self.file_name, self.world_size),
+        )
+        if use_gpu:
+            torch.cuda.set_device(f"cuda:{self.rank}")
+            device = torch.device("cuda")
+        else:
+            device = torch.device("cpu")
+
+        torch._dynamo.config.optimize_ddp = (
+            "python_reducer_without_compiled_forward"
+            if no_compile_forward
+            else "python_reducer"
+        )
+        torch.manual_seed(123)
+        model = Net().to(device)
+        input = torch.randn([1, DIM], device=device)
+
+        compiled_model = torch.compile(replicate(deepcopy(model)), fullgraph=True)
+        compiled_optim = torch.optim.Adam(compiled_model.parameters())
+        model = replicate(model)
+        optim = torch.optim.Adam(model.parameters())
+
+        if setup_func:
+            setup_func(model, compiled_model)
+
+        # Run multiple iterations so that we can test no_sync.
+        for i in range(2):
+            # Set a different random seed per rank and iteration so that if the
+            # allreduces are not executed correctly, the gradients won't match
+            # those of the eager DDP model.
+            torch.manual_seed(123 + self.rank + i)
+            input = torch.randn([1, DIM], device=device)
+
+            if no_sync and i % 2 == 0:
+                context = replicate.state(model)._ddp.no_sync()
+            else:
+                context = contextlib.nullcontext()
+            with context:
+                loss = model(input).sum()
+                loss.backward()
+
+            compiled_m = getattr(compiled_model, "_orig_mod", compiled_model)
+            if no_sync and i % 2 == 0:
+                context = replicate.state(compiled_m)._ddp.no_sync()
+            else:
+                context = contextlib.nullcontext()
+            with context:
+                with compiled_autograd.enable(compiler_fn(no_inductor)):
+                    compiled_loss = compiled_model(input).sum()
+                    compiled_loss.backward()
+
+            if not no_sync or i % 2 == 1:
+                for p1, p2 in zip(model.parameters(), compiled_model.parameters()):
+                    self.assertEqual(p1.grad, p2.grad)
+                compiled_optim.step()
+                # Right now we have to use `set_to_none=False`, otherwise
+                # the backward will be recompiled every iteration.
+                # With `set_to_none=False`, it will only be recompiled once.
+                # https://github.com/pytorch/pytorch/issues/118435
+                compiled_optim.zero_grad(set_to_none=False)
+                optim.step()
+                optim.zero_grad()
+
+        self.assertEqual(tuple(model.parameters()), tuple(compiled_model.parameters()))
+
+    def test_compile_cpu(self):
+        self._test_compile(use_gpu=False, no_sync=False)
+
+    def test_compile_cpu_no_sync(self):
+        self._test_compile(use_gpu=False, no_sync=True)
+
+    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
+    @skip_if_rocm
+    @skip_if_lt_x_gpu(2)
+    def test_compile_gpu(self):
+        self._test_compile(use_gpu=True, no_sync=False)
+
+    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
+    @skip_if_rocm
+    @skip_if_lt_x_gpu(2)
+    def test_compile_bf16(self):
+        def setup(model, compiled_model) -> None:
+            replicate.state(model)._ddp.register_comm_hook(
+                None, ddp_default_hooks.bf16_compress_hook
+            )
+            compiled_m = compiled_model._orig_mod
+            replicate.state(compiled_m)._ddp.register_comm_hook(
+                None, ddp_default_hooks.bf16_compress_hook
+            )
+
+        self._test_compile(
+            use_gpu=True, no_sync=False, setup_func=setup, no_inductor=True
+        )
+
+    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
+    @skip_if_rocm
+    @skip_if_lt_x_gpu(2)
+    def test_compile_fp16(self):
+        def setup(model, compiled_model) -> None:
+            replicate.state(model)._ddp.register_comm_hook(
+                None, ddp_default_hooks.fp16_compress_hook
+            )
+            compiled_m = compiled_model._orig_mod
+            replicate.state(compiled_m)._ddp.register_comm_hook(
+                None, ddp_default_hooks.fp16_compress_hook
+            )
+
+        # TODO: figure out why we need to disable Inductor to avoid test errors.
+        self._test_compile(
+            use_gpu=True, no_sync=False, setup_func=setup, no_inductor=True
+        )
+
+    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
+    @skip_if_rocm
+    @skip_if_lt_x_gpu(2)
+    def test_compile_backward_only(self):
+        self._test_compile(use_gpu=True, no_sync=False, no_compile_forward=True)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py
index 0e91eb007a9a..c96bddf77525 100644
--- a/torch/_dynamo/config.py
+++ b/torch/_dynamo/config.py
@@ -5,7 +5,7 @@ import re
 import sys
 import tempfile
 from os.path import abspath, dirname
-from typing import Any, Dict, Optional, Set, Type, TYPE_CHECKING
+from typing import Any, Dict, Optional, Set, Type, TYPE_CHECKING, Union
 
 import torch
 
@@ -220,12 +220,50 @@ force_unspec_int_unbacked_size_like_on_torchrec_kjt = False
 # false_fn produces code with identical guards.
 enforce_cond_guards_match = True
 
-# Automatically split model graph into pieces to match DDP bucket sizes
-# to allow DDP comm/compute overlap. Disable to allow DDP models to
-# run without graph-breaks, but also without comm/compute overlap.
-# set TORCH_LOGS env to include any of 'dynamo', 'distributed', or
-# 'dist_ddp' for more info about optimize_ddp behavior.
-optimize_ddp = True
+# Specify how to optimize a compiled DDP module. The flag accepts a boolean
+# value or a string. There are 4 modes.
+# 1. "ddp_optimizer" (or True): with "ddp_optimizer", Dynamo will automatically
+# split the model graph into pieces that match the DDP bucket sizes to allow
+# DDP comm/compute overlap.
+# 2. "python_reducer" (experimental): this optimization requires the usage
+# of compiled_autograd. With "python_reducer", DDP will disable the C++ reducer
+# and use the Python reducer to allow compiled_autograd to trace the
+# communication and allow comm/compute overlap without graph-breaks.
+# 3. "python_reducer_without_compiled_forward" (experimental): this mode is
+# similar to "python_reducer". One should only use this optimization mode
+# when compiled_autograd is used but the DDP module is not compiled.
+# 4. "no_optimization" (or False): Dynamo won't split the model graph, nor
+# will the Python reducer be used. With this mode, there will be no graph-breaks
+# and the original DDP C++ reducer will be used. There will be no comm/compute
+# overlap. This mode CANNOT be used with compiled_autograd.
+# Note that to avoid breaking the existing usage, mode 1 and mode 4 can be
+# specified with a boolean value: True means "ddp_optimizer" and False means
+# "no_optimization".
+optimize_ddp: Union[bool, str] = True
+
+_ddp_optimization_mode = [
+    "ddp_optimizer",
+    "python_reducer",  # experimental mode
+    "python_reducer_without_compiled_forward",  # experimental mode
+    "no_optimization",
+]
+
+
+def _get_optimize_ddp_mode():
+    m = sys.modules[__name__]
+    if isinstance(m.optimize_ddp, bool):
+        if m.optimize_ddp:
+            mode = "ddp_optimizer"
+        else:
+            mode = "no_optimization"
+    elif isinstance(m.optimize_ddp, str):
+        mode = m.optimize_ddp
+    else:
+        raise ValueError(f"Invalid type, {type(optimize_ddp)=}")
+
+    assert mode in m._ddp_optimization_mode, f"Invalid mode {mode=}"
+    return mode
+
 
 # If True, delays DDPOptimizer submodule compilation to 1st run of the model,
 # so that real tensor strides are used in all submodules
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index b0758b15fae6..6b311653fc4c 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -881,7 +881,7 @@ def catch_errors_wrapper(callback, hooks: Hooks):
         if frame.f_code.co_filename == "<string>" and frame.f_code.co_name == "__new__":
             # nametuple constructor
             return None
-        if config.optimize_ddp:
+        if config._get_optimize_ddp_mode() == "ddp_optimizer":
             ddp_module = DistributedDataParallel._get_active_ddp_module()
             if ddp_module:
                 with compile_lock:
diff --git a/torch/_dynamo/skipfiles.py b/torch/_dynamo/skipfiles.py
index ea63242badcf..27cac661d526 100644
--- a/torch/_dynamo/skipfiles.py
+++ b/torch/_dynamo/skipfiles.py
@@ -188,6 +188,9 @@ if torch.distributed.is_available():
         "torch.distributed.tensor.parallel._data_parallel_utils",
         "torch.distributed.tensor.parallel._utils",
         "torch.distributed.tensor.parallel.style",
+        # We have to add replicate to LEGACY_MOD_INLINELIST to ensure
+        # the forward_hook won't be ignored.
+ "torch.distributed._composable.replicate", } @@ -229,6 +232,7 @@ MOD_INLINELIST = { if torch.distributed.is_available(): MOD_INLINELIST.add("torch.distributed") MOD_INLINELIST.add("torch.distributed._functional_collectives") + MOD_INLINELIST.add("torch.distributed._composable.replicate") @functools.lru_cache(None) diff --git a/torch/csrc/autograd/utils/lambda_post_hook.h b/torch/csrc/autograd/utils/lambda_post_hook.h index f22a22312159..ade2389363eb 100644 --- a/torch/csrc/autograd/utils/lambda_post_hook.h +++ b/torch/csrc/autograd/utils/lambda_post_hook.h @@ -9,15 +9,18 @@ namespace utils { // Turns lambda into a torch::autograd::FunctionPostHook. class LambdaPostHook : public torch::autograd::FunctionPostHook { using variable_list = std::vector; + using fn_type = + std::function; + using compiled_fn_type = std::function; public: // The lambda function takes as arguments the outputs and inputs of the // autograd function and can modify the outputs of the autograd function by // returning a new output if needed. - /* implicit */ LambdaPostHook( - std::function - fn) - : fn_(std::move(fn)) {} + /* implicit */ LambdaPostHook(fn_type fn) : fn_(std::move(fn)) {} + + LambdaPostHook(fn_type fn, compiled_fn_type compiled_fn) + : fn_(std::move(fn)), compiled_fn_(std::move(compiled_fn)) {} variable_list operator()( const variable_list& outputs, @@ -25,8 +28,11 @@ class LambdaPostHook : public torch::autograd::FunctionPostHook { return fn_(outputs, inputs); } + void compiled_args(CompiledNodeArgs& args) override {} + protected: std::function fn_; + compiled_fn_type compiled_fn_; }; } // namespace utils diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index 9a3960f0fcc6..c7e8461760d5 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -193,6 +193,9 @@ Reducer::Reducer( #endif this->autograd_hook(variable_index); return outputs; + }, + [=](torch::autograd::CompiledNodeArgs& args) { + // Make post_hook an noop if compiled_autograds is enabled. })), grad_accumulator); diff --git a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py index abdca49f87bf..bff55327e847 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py @@ -1,9 +1,16 @@ -from typing import Any, Callable +from typing import Any, Callable, cast, Tuple import torch import torch.distributed as dist -__all__ = ["allreduce_hook", "fp16_compress_hook", "bf16_compress_hook", "fp16_compress_wrapper", "bf16_compress_wrapper"] +__all__ = [ + "allreduce_hook", + "fp16_compress_hook", + "bf16_compress_hook", + "fp16_compress_wrapper", + "bf16_compress_wrapper", +] + def _allreduce_fut( process_group: dist.ProcessGroup, tensor: torch.Tensor @@ -44,7 +51,8 @@ def allreduce_hook( def fp16_compress_hook( - process_group: dist.ProcessGroup, bucket: dist.GradBucket + process_group: dist.ProcessGroup, + bucket: dist.GradBucket, ) -> torch.futures.Future[torch.Tensor]: """ Compress by casting ``GradBucket`` to ``torch.float16`` divided by process group size. 
@@ -62,24 +70,37 @@ def fp16_compress_hook(
     group_to_use = process_group if process_group is not None else dist.group.WORLD
     world_size = group_to_use.size()
 
-    compressed_tensor = bucket.buffer().to(torch.float16).div_(world_size)
-
-    fut = dist.all_reduce(
-        compressed_tensor, group=group_to_use, async_op=True
-    ).get_future()
+    buffer = (
+        cast(Tuple[torch.Tensor, ...], bucket)[0]
+        if isinstance(bucket, tuple)
+        else bucket.buffer()
+    )
+    compressed_tensor = buffer.to(torch.float16).div_(world_size)
 
     def decompress(fut):
-        decompressed_tensor = bucket.buffer()
+        decompressed_tensor = buffer
         # Decompress in place to reduce the peak memory.
         # See: https://github.com/pytorch/pytorch/issues/45968
-        decompressed_tensor.copy_(fut.value()[0])
+        value = fut if isinstance(fut, torch.Tensor) else fut.value()[0]
+        decompressed_tensor.copy_(value)
         return decompressed_tensor
 
-    return fut.then(decompress)
+    if torch._utils.is_compiling():
+        grad = dist._functional_collectives.all_reduce(
+            compressed_tensor, "sum", group_to_use
+        )
+        return decompress(grad)
+    else:
+        fut = dist.all_reduce(
+            compressed_tensor, group=group_to_use, async_op=True
+        ).get_future()
+        return fut.then(decompress)
+
 
 # TODO: create an internal helper function and extract the duplicate code in FP16_compress and BF16_compress.
 def bf16_compress_hook(
-    process_group: dist.ProcessGroup, bucket: dist.GradBucket
+    process_group: dist.ProcessGroup,
+    bucket: dist.GradBucket,
 ) -> torch.futures.Future[torch.Tensor]:
     """
     Warning: This API is experimental, and it requires NCCL version later than 2.9.6.
@@ -98,20 +119,31 @@ def bf16_compress_hook(
     group_to_use = process_group if process_group is not None else dist.group.WORLD
     world_size = group_to_use.size()
 
-    compressed_tensor = bucket.buffer().to(torch.bfloat16).div_(world_size)
-
-    fut = dist.all_reduce(
-        compressed_tensor, group=group_to_use, async_op=True
-    ).get_future()
+    buffer = (
+        cast(Tuple[torch.Tensor, ...], bucket)[0]
+        if isinstance(bucket, tuple)
+        else bucket.buffer()
+    )
+    compressed_tensor = buffer.to(torch.bfloat16).div_(world_size)
 
     def decompress(fut):
-        decompressed_tensor = bucket.buffer()
+        decompressed_tensor = buffer
         # Decompress in place to reduce the peak memory.
         # See: https://github.com/pytorch/pytorch/issues/45968
-        decompressed_tensor.copy_(fut.value()[0])
+        value = fut if isinstance(fut, torch.Tensor) else fut.value()[0]
+        decompressed_tensor.copy_(value)
         return decompressed_tensor
 
-    return fut.then(decompress)
+    if torch._utils.is_compiling():
+        grad = dist._functional_collectives.all_reduce(
+            compressed_tensor, "sum", group_to_use
+        )
+        return decompress(grad)
+    else:
+        fut = dist.all_reduce(
+            compressed_tensor, group=group_to_use, async_op=True
+        ).get_future()
+        return fut.then(decompress)
 
 
 def fp16_compress_wrapper(
@@ -151,6 +183,7 @@ def fp16_compress_wrapper(
     return fp16_compress_wrapper_hook
+
 
 def bf16_compress_wrapper(
     hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]
 ) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index d527150e5f65..23800be6685c 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -11,14 +11,14 @@ from collections import defaultdict, deque
 from contextlib import contextmanager
 from dataclasses import dataclass, fields, is_dataclass
 from enum import auto, Enum
-from typing import Any, Callable, List, Optional, Type
+from typing import Any, Callable, List, Optional, Tuple, Type
 
 import torch
 import torch.distributed as dist
 from torch.autograd import Function, Variable
 from torch.distributed.algorithms.join import Join, Joinable, JoinHook
-
 from torch.utils._pytree import tree_flatten, tree_unflatten
+from torch.utils.hooks import RemovableHandle
 
 RPC_AVAILABLE = False
 if dist.is_available():
@@ -815,6 +815,8 @@ class DistributedDataParallel(Module, Joinable):
             param_to_name_mapping,
             static_graph,
         )
+        self._comm_hooks: List[Tuple[Callable, object]] = []
+
         if self.mixed_precision is not None:
             _setup_mixed_precision_params(self.mixed_precision, self.module)
             _cast_buffers(self.mixed_precision, self.module)
@@ -864,6 +866,50 @@ class DistributedDataParallel(Module, Joinable):
 
         self._lazy_init_ran = False
 
+        # Register the AccumulateGrad post hooks if optimize_ddp selects the
+        # Python reducer. The hooks will be deregistered if compiled_autograd
+        # is not enabled.
+        self._accum_grad_hooks: List[RemovableHandle] = []
+        optimize_ddp = torch._dynamo.config._get_optimize_ddp_mode()
+        self._use_python_reducer = optimize_ddp in (
+            "python_reducer",
+            "python_reducer_without_compiled_forward",
+        )
+        self._force_to_disable_cpp_reducer = (
+            optimize_ddp == "python_reducer_without_compiled_forward"
+        )
+        if self._use_python_reducer:
+            self._register_accum_grad_hook()
+
+    def _register_accum_grad_hook(self):
+        import torch.distributed._functional_collectives as fcol
+
+        def compiled_accum_grad_hook(
+            param,
+            *,
+            param_index: int,
+        ):
+            if not self.require_backward_grad_sync:
+                return
+
+            if self._comm_hooks:
+                for hook, state in self._comm_hooks:
+                    hook(state, (param.grad, param))
+            else:
+                gradient = param.grad / self.process_group.size()
+                gradient = fcol.all_reduce(gradient, "sum", self.process_group)
+                param.grad.copy_(gradient)
+
+        for index, param in enumerate(self._module_parameters):
+            self._accum_grad_hooks.append(
+                param.register_post_accumulate_grad_hook(
+                    functools.partial(
+                        compiled_accum_grad_hook,
+                        param_index=index,
+                    )
+                )
+            )
+
     def _delayed_all_reduce_hook(self, grad):
         world_size = dist.get_world_size(self.process_group)
 
@@ -1355,8 +1401,11 @@ class DistributedDataParallel(Module, Joinable):
             DistributedDataParallel._active_ddp_module = None
 
     def _run_ddp_forward(self, *inputs, **kwargs):
-        with self._inside_ddp_forward():
+        if self._use_python_reducer:
             return self.module(*inputs, **kwargs)  # type: ignore[index]
+        else:
+            with self._inside_ddp_forward():
+                return self.module(*inputs, **kwargs)  # type: ignore[index]
 
     def _clear_grad_buffer(self):
         # Making param.grad points to the grad buffers before backward is based on the
@@ -1385,9 +1434,24 @@ class DistributedDataParallel(Module, Joinable):
         self._setup_in_backward_optimizers()
         self._lazy_init_ran = True
 
+    def _should_disable_cpp_reducer(self) -> bool:
+        return self._use_python_reducer and (
+            torch._utils.is_compiling() or self._force_to_disable_cpp_reducer
+        )
+
     def _pre_forward(self, *inputs, **kwargs):
-        if not self._lazy_init_ran:
+        if self._should_disable_cpp_reducer():
+            return inputs, kwargs
+
+        # Disable the Python reducer if compiled_autograd is not enabled.
+        if self._accum_grad_hooks:
+            for index, h in enumerate(self._accum_grad_hooks):
+                h.remove()
+            self._accum_grad_hooks.clear()
+
+        if not self._lazy_init_ran and not torch._utils.is_compiling():
             self._lazy_init()
+
         if self._delay_all_reduce_all_params:
             return inputs, kwargs
 
@@ -1451,6 +1515,9 @@ class DistributedDataParallel(Module, Joinable):
         return inputs, kwargs
 
     def _post_forward(self, output):
+        if self._should_disable_cpp_reducer():
+            return output
+
         if self._delay_all_reduce_all_params:
             self._clear_grad_buffer()
             return output
@@ -1883,8 +1950,16 @@ class DistributedDataParallel(Module, Joinable):
             >>> ddp.register_comm_hook(state=None, hook=encode_and_decode)
         """
         self._check_comm_hook(hook)
+        if hook.__name__ in ["bf16_compress_hook", "fp16_compress_hook"]:
+            # If we pass None, then the hook will try to get the world size
+            # by calling `dist.group.WORLD.size()`, which causes compilation
+            # errors. So we pre-resolve the process group here and pass it to
+            # the hook.
+            if state is None:
+                state = dist.group.WORLD
         assert self.logger is not None
         self.logger._set_comm_hook_name(hook.__qualname__)
+        self._comm_hooks.append((hook, state))
         dist._register_comm_hook(self.reducer, state, hook)
 
     def _register_builtin_comm_hook(self, comm_hook_type):
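
Usage sketch (not part of the patch): the new "python_reducer" mode is exercised end to end by the test added above. The minimal flow below is a hedged sketch under stated assumptions: the script is launched with torchrun so the usual rank/world-size environment is set, the gloo backend keeps it CPU-only, and the tiny model and input are placeholders. The test wraps the model with the composable replicate() API, which builds on DDP internally; plain DDP is assumed to behave the same here.

    import torch
    import torch.distributed as dist
    from torch import nn
    from torch._dynamo import compiled_autograd
    from torch.nn.parallel import DistributedDataParallel as DDP

    # Assumes torchrun has set RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT.
    dist.init_process_group("gloo")

    # Route DDP through the experimental Python reducer so compiled_autograd
    # can trace the gradient allreduces instead of graph-breaking around the
    # C++ reducer (see the config.py change above).
    torch._dynamo.config.optimize_ddp = "python_reducer"

    model = DDP(nn.Linear(8, 8))
    compiled_model = torch.compile(model)

    def compiler_fn(gm):
        # Also compile the backward graph captured by compiled_autograd,
        # mirroring compiler_fn in the test above.
        return torch.compile(gm, fullgraph=True)

    inp = torch.randn(4, 8)
    with compiled_autograd.enable(compiler_fn):
        loss = compiled_model(inp).sum()
        loss.backward()

    dist.destroy_process_group()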
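The comm-hook path is covered as well: with the Python reducer, compiled_accum_grad_hook replays every hook recorded in DDP's new _comm_hooks list once per parameter, and register_comm_hook now pre-resolves state=None to dist.group.WORLD for the fp16/bf16 hooks so the default group does not have to be looked up during tracing. A short, hedged continuation of the sketch above (plain-DDP registration is assumed to match the replicate-based test_compile_fp16):

    from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

    # Register before compiling/backward; with the Python reducer the hook is
    # invoked per parameter with (param.grad, param) instead of a GradBucket,
    # which is why fp16_compress_hook above now also accepts a tuple.
    model.register_comm_hook(None, default_hooks.fp16_compress_hook)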