[ddp] propagate use_python_reducer to C++ reducer (#152735)

The C++ Reducer is silently incorrect under compiled autograd (CA): its implementation no-ops the collective. I'm guessing it was no-op'd because, in DDP + python reducer, the C++ reducer is still being initialized.
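
For context, a minimal sketch of the workaround the new error message points at, mirroring the new tests below (the fake process group, the toy model, and the use of the private compiled_autograd._enable helper are illustrative, not part of this change):

```python
# Minimal sketch: switch DDP to the python reducer before wrapping the model.
# The fake backend/FakeStore only make this runnable in a single process.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch._dynamo import compiled_autograd
from torch.testing._internal.distributed.fake_pg import FakeStore

# Without this flag, the C++ Reducer's compiled_args hook now raises a
# RuntimeError under compiled autograd instead of silently no-oping.
torch._dynamo.config.optimize_ddp = "python_reducer"

dist.init_process_group(backend="fake", rank=0, world_size=2, store=FakeStore())
model = torch.nn.parallel.DistributedDataParallel(nn.Linear(10, 10))
loss = model(torch.randn(4, 10)).sum()
with compiled_autograd._enable(torch.compile):
    loss.backward()
dist.destroy_process_group()
```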

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152735
Approved by: https://github.com/fegin
ghstack dependencies: #153300, #152689
This commit is contained in:
Simon Fan
2025-05-15 11:20:34 -07:00
committed by PyTorch MergeBot
parent 1b4749f748
commit d1f1ff8610
7 changed files with 79 additions and 11 deletions

View File

@@ -19,6 +19,7 @@ from string import Template
from unittest import mock
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import _inductor as inductor
@@ -30,6 +31,7 @@ from torch._dynamo.utils import counters
from torch._inductor import config as inductor_config
from torch._inductor.test_case import run_tests, TestCase
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
ops,
@@ -4161,6 +4163,54 @@ class CompiledAutograd1(torch.nn.Module):
        first, second, third, fourth = fn(eager(), aot_eager())
        self.assertIsNone(third)

    @unittest.skipIf(
        not torch.distributed.is_available(),
        "FakePG relies on distributed build",
    )
    def test_ddp_cpp_reducer_error(self):
        from torch.testing._internal.distributed.fake_pg import FakeStore

        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
        try:
            model = torch.nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
            model = DDP(model)
            inputs = torch.randn(10, 10)
            loss = model(inputs).sum()
            with compiled_autograd._enable(compiler_fn), self.assertRaisesRegex(
                RuntimeError,
                (
                    r"Compiled autograd is not compatible with C\+\+ DDP Reducer, "
                    r'please use torch._dynamo.config.optimize_ddp="python_reducer"'
                ),
            ):
                loss.backward()
        finally:
            dist.destroy_process_group()

    @unittest.skipIf(
        not torch.distributed.is_available(),
        "FakePG relies on distributed build",
    )
    @config.patch(optimize_ddp="python_reducer")
    def test_ddp_python_reducer(self):
        from torch.testing._internal.distributed.fake_pg import FakeStore

        store = FakeStore()
        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
        try:
            model = torch.nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
            model = DDP(model)
            inputs = torch.randn(10, 10)
            loss = model(inputs).sum()
            with compiled_autograd._enable(compiler_fn):
                # no error expected
                loss.backward()
            self.assertEqual(counters["compiled_autograd"]["captures"], 1)
        finally:
            dist.destroy_process_group()


def load_test_module(name):
    testdir = Path(__file__).absolute().parent.parent

View File

@@ -51,6 +51,7 @@ class Reducer:
        param_to_name_mapping: dict[int, str] = ...,
        first_bucket_types_cap: int = ..., # kDefaultFirstBucketBytes in reducer.hpp
        skip_all_reduce_unused_params: bool = ...,
        use_python_reducer: bool = ...,
    ) -> None: ...
    def prepare_for_forward(self) -> None: ...
    def prepare_for_backward(self, output: list[Tensor]) -> None: ...

View File

@@ -27,7 +27,12 @@ class LambdaPostHook : public torch::autograd::FunctionPostHook {
return fn_(outputs, inputs);
}
void compiled_args(CompiledNodeArgs& args) const override {}
void compiled_args(CompiledNodeArgs& args) const override {
if (compiled_fn_ != nullptr) {
return compiled_fn_(args);
}
return FunctionPostHook::compiled_args(args);
}
protected:
std::function<variable_list(const variable_list&, const variable_list&)> fn_;

View File

@@ -560,7 +560,8 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
bool gradient_as_bucket_view,
std::unordered_map<size_t, std::string> param_to_name_mapping,
int64_t first_bucket_bytes_cap,
bool skip_all_reduce_unused_params) {
bool skip_all_reduce_unused_params,
bool use_python_reducer) {
// gil_scoped_release is not safe as a call_guard in init.
// https://github.com/pybind/pybind11/issues/5473
py::gil_scoped_release nogil{};
@@ -575,7 +576,8 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
gradient_as_bucket_view,
std::move(param_to_name_mapping),
first_bucket_bytes_cap,
skip_all_reduce_unused_params);
skip_all_reduce_unused_params,
use_python_reducer);
}),
py::arg("params"),
py::arg("bucket_indices"),
@@ -588,7 +590,8 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
py::arg("param_to_name_mapping") =
std::unordered_map<size_t, std::string>(),
py::arg("first_bucket_bytes_cap") = ::c10d::kDefaultFirstBucketBytes,
py::arg("skip_all_reduce_unused_params") = false)
py::arg("skip_all_reduce_unused_params") = false,
py::arg("use_python_reducer") = false)
.def(
"prepare_for_forward",
&::c10d::Reducer::prepare_for_forward,

View File

@@ -97,7 +97,8 @@ Reducer::Reducer(
bool gradient_as_bucket_view,
std::unordered_map<size_t, std::string> param_names,
int64_t first_bucket_bytes_cap,
bool skip_all_reduce_unused_params)
bool skip_all_reduce_unused_params,
bool use_python_reducer)
: params_(std::move(params)),
process_group_(std::move(process_group)),
expect_sparse_gradients_(std::move(expect_sparse_gradients)),
@@ -121,7 +122,8 @@ Reducer::Reducer(
comm_hook_(nullptr),
ddp_debug_level_(debug_level()),
param_names_(std::move(param_names)),
first_bucket_bytes_cap_(first_bucket_bytes_cap) {
first_bucket_bytes_cap_(first_bucket_bytes_cap),
use_python_reducer_(use_python_reducer) {
C10_LOG_API_USAGE_ONCE("torch.distributed.ddp.reducer");
TORCH_INTERNAL_ASSERT(!params_.empty(), "Expected at least one parameter.");
@@ -199,8 +201,9 @@ Reducer::Reducer(
this->autograd_hook(variable_index);
return outputs;
},
[=](torch::autograd::CompiledNodeArgs& args) {
TORCH_INTERNAL_ASSERT(
[this](torch::autograd::CompiledNodeArgs& args) {
TORCH_CHECK(
this->use_python_reducer_,
"Compiled autograd is not compatible with C++ DDP Reducer, please use torch._dynamo.config.optimize_ddp=\"python_reducer\".");
})),
grad_accumulator);

View File

@@ -58,7 +58,8 @@ class TORCH_API Reducer {
bool gradient_as_bucket_view,
std::unordered_map<size_t, std::string> param_names,
int64_t first_bucket_bytes_cap,
bool skip_all_reduce_unused_params);
bool skip_all_reduce_unused_params,
bool use_python_reducer);
~Reducer() noexcept(false);
@@ -562,6 +563,9 @@ class TORCH_API Reducer {
void checkAndRaiseMarkedTwiceError(size_t curVariableIndex);
// Retrieves parameter corresponding to the given VariableIndex.
at::Tensor& get_param_from_index(size_t index);
// Python reducer keeps C++ reducer initialized. To remove this flag,
// we need to refactor the DDP wrapper's initialization.
bool use_python_reducer_;
// Cached bucket index to model parameter mapping. Populated after buckets
// are rebuilt after which this mapping is static.

View File

@@ -657,6 +657,9 @@ class DistributedDataParallel(Module, Joinable):
    ):
        super().__init__()
        Joinable.__init__(self)
        self._use_python_reducer = (
            torch._dynamo.utils.get_optimize_ddp_mode() == "python_reducer"
        )
        self.logger: Optional[dist.Logger] = None
        if bool(delay_all_reduce_named_params is not None) != bool(
            param_to_hook_all_reduce is not None
@@ -915,8 +918,6 @@ class DistributedDataParallel(Module, Joinable):
        # True. The hooks will be deregistered if compiled_autograd is not
        # enabled.
        self._accum_grad_hooks: list[RemovableHandle] = []
        optimize_ddp = torch._dynamo.utils.get_optimize_ddp_mode()
        self._use_python_reducer = optimize_ddp == "python_reducer"
        if self._use_python_reducer:
            torch._inductor.config._fuse_ddp_communication = True
            torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb
@@ -1228,6 +1229,7 @@ class DistributedDataParallel(Module, Joinable):
                else self.bucket_bytes_cap
            ),
            self.skip_all_reduce_unused_params,
            self._use_python_reducer,
        )
        self.logger = dist.Logger(self.reducer)