[ddp] propagate use_python_reducer to C++ reducer (#152735)
The C++ Reducer is silently incorrect under compiled autograd (CA): its implementation no-ops the collective. I'm guessing it was no-op'd because, in DDP + python reducer, the C++ reducer is still being initialized.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152735
Approved by: https://github.com/fegin
ghstack dependencies: #153300, #152689
Committed by: PyTorch MergeBot
Parent: 1b4749f748
Commit: d1f1ff8610
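For orientation, here is a condensed, hypothetical sketch of the user-facing flow that the new tests below exercise; it is not part of the PR. The `compiler_fn` helper is an assumption standing in for the one already defined in the compiled autograd test suite, and the fake process group mirrors the tests so the snippet runs on a single rank.

import torch
import torch.distributed as dist
import torch.nn as nn
from torch._dynamo import compiled_autograd
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.distributed.fake_pg import FakeStore


def compiler_fn(gm):
    # Assumed stand-in for the test suite's compiler_fn helper.
    return torch.compile(gm, backend="eager")


# Select the python reducer *before* constructing DDP: with this PR, DDP reads
# the mode in __init__ and forwards it to the C++ Reducer as use_python_reducer.
torch._dynamo.config.optimize_ddp = "python_reducer"

store = FakeStore()
dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
try:
    model = DDP(nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10)))
    loss = model(torch.randn(10, 10)).sum()
    # Because use_python_reducer is now propagated, the C++ reducer's
    # compiled_args hook passes its TORCH_CHECK here; without the config set
    # above, backward() under compiled autograd raises a RuntimeError instead
    # of silently skipping the allreduce.
    with compiled_autograd._enable(compiler_fn):
        loss.backward()
finally:
    dist.destroy_process_group()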
@@ -19,6 +19,7 @@ from string import Template
 from unittest import mock
 
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import _inductor as inductor
@@ -30,6 +31,7 @@ from torch._dynamo.utils import counters
 from torch._inductor import config as inductor_config
 from torch._inductor.test_case import run_tests, TestCase
 from torch.nn.attention.flex_attention import flex_attention
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     ops,
@@ -4161,6 +4163,54 @@ class CompiledAutograd1(torch.nn.Module):
         first, second, third, fourth = fn(eager(), aot_eager())
         self.assertIsNone(third)
 
+    @unittest.skipIf(
+        not torch.distributed.is_available(),
+        "FakePG relies on distributed build",
+    )
+    def test_ddp_cpp_reducer_error(self):
+        from torch.testing._internal.distributed.fake_pg import FakeStore
+
+        store = FakeStore()
+        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
+        try:
+            model = torch.nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
+            model = DDP(model)
+            inputs = torch.randn(10, 10)
+            loss = model(inputs).sum()
+            with compiled_autograd._enable(compiler_fn), self.assertRaisesRegex(
+                RuntimeError,
+                (
+                    r"Compiled autograd is not compatible with C\+\+ DDP Reducer, "
+                    r'please use torch._dynamo.config.optimize_ddp="python_reducer"'
+                ),
+            ):
+                loss.backward()
+
+        finally:
+            dist.destroy_process_group()
+
+    @unittest.skipIf(
+        not torch.distributed.is_available(),
+        "FakePG relies on distributed build",
+    )
+    @config.patch(optimize_ddp="python_reducer")
+    def test_ddp_python_reducer(self):
+        from torch.testing._internal.distributed.fake_pg import FakeStore
+
+        store = FakeStore()
+        dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
+        try:
+            model = torch.nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10))
+            model = DDP(model)
+            inputs = torch.randn(10, 10)
+            loss = model(inputs).sum()
+            with compiled_autograd._enable(compiler_fn):
+                # no error expected
+                loss.backward()
+            self.assertEqual(counters["compiled_autograd"]["captures"], 1)
+        finally:
+            dist.destroy_process_group()
+
 
 def load_test_module(name):
     testdir = Path(__file__).absolute().parent.parent
@@ -51,6 +51,7 @@ class Reducer:
         param_to_name_mapping: dict[int, str] = ...,
         first_bucket_types_cap: int = ...,  # kDefaultFirstBucketBytes in reducer.hpp
         skip_all_reduce_unused_params: bool = ...,
+        use_python_reducer: bool = ...,
     ) -> None: ...
     def prepare_for_forward(self) -> None: ...
     def prepare_for_backward(self, output: list[Tensor]) -> None: ...
@@ -27,7 +27,12 @@ class LambdaPostHook : public torch::autograd::FunctionPostHook {
     return fn_(outputs, inputs);
   }
 
-  void compiled_args(CompiledNodeArgs& args) const override {}
+  void compiled_args(CompiledNodeArgs& args) const override {
+    if (compiled_fn_ != nullptr) {
+      return compiled_fn_(args);
+    }
+    return FunctionPostHook::compiled_args(args);
+  }
 
  protected:
   std::function<variable_list(const variable_list&, const variable_list&)> fn_;
@@ -560,7 +560,8 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_COMPRESS``.
              bool gradient_as_bucket_view,
              std::unordered_map<size_t, std::string> param_to_name_mapping,
              int64_t first_bucket_bytes_cap,
-             bool skip_all_reduce_unused_params) {
+             bool skip_all_reduce_unused_params,
+             bool use_python_reducer) {
            // gil_scoped_release is not safe as a call_guard in init.
            // https://github.com/pybind/pybind11/issues/5473
            py::gil_scoped_release nogil{};
@@ -575,7 +576,8 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_COMPRESS``.
                gradient_as_bucket_view,
                std::move(param_to_name_mapping),
                first_bucket_bytes_cap,
-               skip_all_reduce_unused_params);
+               skip_all_reduce_unused_params,
+               use_python_reducer);
          }),
          py::arg("params"),
          py::arg("bucket_indices"),
@@ -588,7 +590,8 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_COMPRESS``.
          py::arg("param_to_name_mapping") =
              std::unordered_map<size_t, std::string>(),
          py::arg("first_bucket_bytes_cap") = ::c10d::kDefaultFirstBucketBytes,
-         py::arg("skip_all_reduce_unused_params") = false)
+         py::arg("skip_all_reduce_unused_params") = false,
+         py::arg("use_python_reducer") = false)
          .def(
              "prepare_for_forward",
              &::c10d::Reducer::prepare_for_forward,
@@ -97,7 +97,8 @@ Reducer::Reducer(
     bool gradient_as_bucket_view,
     std::unordered_map<size_t, std::string> param_names,
     int64_t first_bucket_bytes_cap,
-    bool skip_all_reduce_unused_params)
+    bool skip_all_reduce_unused_params,
+    bool use_python_reducer)
     : params_(std::move(params)),
       process_group_(std::move(process_group)),
       expect_sparse_gradients_(std::move(expect_sparse_gradients)),
@@ -121,7 +122,8 @@ Reducer::Reducer(
       comm_hook_(nullptr),
       ddp_debug_level_(debug_level()),
       param_names_(std::move(param_names)),
-      first_bucket_bytes_cap_(first_bucket_bytes_cap) {
+      first_bucket_bytes_cap_(first_bucket_bytes_cap),
+      use_python_reducer_(use_python_reducer) {
   C10_LOG_API_USAGE_ONCE("torch.distributed.ddp.reducer");
   TORCH_INTERNAL_ASSERT(!params_.empty(), "Expected at least one parameter.");
 
@@ -199,8 +201,9 @@ Reducer::Reducer(
             this->autograd_hook(variable_index);
             return outputs;
           },
-          [=](torch::autograd::CompiledNodeArgs& args) {
-            TORCH_INTERNAL_ASSERT(
+          [this](torch::autograd::CompiledNodeArgs& args) {
+            TORCH_CHECK(
+                this->use_python_reducer_,
                 "Compiled autograd is not compatible with C++ DDP Reducer, please use torch._dynamo.config.optimize_ddp=\"python_reducer\".");
           })),
       grad_accumulator);
@@ -58,7 +58,8 @@ class TORCH_API Reducer {
       bool gradient_as_bucket_view,
       std::unordered_map<size_t, std::string> param_names,
       int64_t first_bucket_bytes_cap,
-      bool skip_all_reduce_unused_params);
+      bool skip_all_reduce_unused_params,
+      bool use_python_reducer);
 
   ~Reducer() noexcept(false);
 
@@ -562,6 +563,9 @@ class TORCH_API Reducer {
   void checkAndRaiseMarkedTwiceError(size_t curVariableIndex);
   // Retrieves parameter corresponding to the given VariableIndex.
   at::Tensor& get_param_from_index(size_t index);
+  // Python reducer keeps C++ reducer initialized. To remove this flag,
+  // we need to refactor the DDP wrapper's initialization.
+  bool use_python_reducer_;
 
   // Cached bucket index to model parameter mapping. Populated after buckets
   // are rebuilt after which this mapping is static.
@@ -657,6 +657,9 @@ class DistributedDataParallel(Module, Joinable):
     ):
         super().__init__()
         Joinable.__init__(self)
+        self._use_python_reducer = (
+            torch._dynamo.utils.get_optimize_ddp_mode() == "python_reducer"
+        )
         self.logger: Optional[dist.Logger] = None
         if bool(delay_all_reduce_named_params is not None) != bool(
             param_to_hook_all_reduce is not None
@@ -915,8 +918,6 @@ class DistributedDataParallel(Module, Joinable):
         # True. The hooks will be deregistered if compiled_autograd is not
         # enabled.
         self._accum_grad_hooks: list[RemovableHandle] = []
-        optimize_ddp = torch._dynamo.utils.get_optimize_ddp_mode()
-        self._use_python_reducer = optimize_ddp == "python_reducer"
         if self._use_python_reducer:
             torch._inductor.config._fuse_ddp_communication = True
             torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb
@@ -1228,6 +1229,7 @@ class DistributedDataParallel(Module, Joinable):
                 else self.bucket_bytes_cap
             ),
             self.skip_all_reduce_unused_params,
+            self._use_python_reducer,
         )
 
         self.logger = dist.Logger(self.reducer)