mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[RFC] Add an API to remove autograd hooks from DDP (#96490)
Summary: When creating a new DDP instance for the same model when an old DDP instance existed, the autograd hooks from the old DDP instance might not be cleared. Also, relying on python gc to clear out old autograd hooks is fragile and may not work 100% of the time. As a result, in this PR I'm adding a way to explicitly remove these hooks from DDP Test Plan: Unit test added Pull Request resolved: https://github.com/pytorch/pytorch/pull/96490 Approved by: https://github.com/zhaojuanmao, https://github.com/rohan-varma
This commit is contained in:
committed by
PyTorch MergeBot
parent
fa82080016
commit
e20e5f5578
@ -60,6 +60,7 @@ class Reducer:
|
||||
def _set_static_graph(self) -> None: ...
|
||||
def _run_comm_hook(self, bucket: GradBucket) -> Future: ...
|
||||
def set_logger(self, logger: Logger) -> None: ...
|
||||
def _remove_autograd_hooks(self) -> None: ...
|
||||
|
||||
class DDPLoggingData:
|
||||
strs_map: Dict[str, str]
|
||||
|
@ -517,7 +517,11 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
|
||||
const std::shared_ptr<::c10d::Logger> logger) {
|
||||
std::weak_ptr<::c10d::Logger> logger_weakref = logger;
|
||||
reducer.set_logger(logger_weakref);
|
||||
});
|
||||
})
|
||||
.def(
|
||||
"_remove_autograd_hooks",
|
||||
[](::c10d::Reducer& reducer) { reducer.remove_autograd_hooks(); },
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
|
||||
shared_ptr_class_<::c10d::Logger>(module, "Logger")
|
||||
.def(
|
||||
|
@ -259,18 +259,7 @@ Reducer::Reducer(
|
||||
// be specified by calling `register_builtin_comm_hook` from Python API.
|
||||
|
||||
Reducer::~Reducer() noexcept(false) {
|
||||
// Remove all hooks on variables registered by this Reducer. This is necessary
|
||||
// to make DDP failure recoverable. Otherwise, multiple Reducer instances
|
||||
// (from recoveries) will add their hooks to the original model, and those
|
||||
// hooks will try to invoke methods on a deleted Reducer objects.
|
||||
for (auto& hook : hooks_) {
|
||||
auto& key = hook.first;
|
||||
auto& grad_accumulator = hook.second;
|
||||
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
grad_accumulator->del_post_hook(key),
|
||||
"Reducer attempts to delete a non-existing hook.");
|
||||
}
|
||||
remove_autograd_hooks();
|
||||
}
|
||||
|
||||
bool Reducer::dynamic_graph_find_unused() {
|
||||
@ -2240,4 +2229,20 @@ void verify_params_across_processes(
|
||||
}
|
||||
}
|
||||
|
||||
void Reducer::remove_autograd_hooks() {
|
||||
// Remove all hooks on variables registered by this Reducer. This is necessary
|
||||
// to make DDP failure recoverable. Otherwise, multiple Reducer instances
|
||||
// (from recoveries) will add their hooks to the original model, and those
|
||||
// hooks will try to invoke methods on a deleted Reducer objects.
|
||||
for (auto& hook : hooks_) {
|
||||
auto& key = hook.first;
|
||||
auto& grad_accumulator = hook.second;
|
||||
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
grad_accumulator->del_post_hook(key),
|
||||
"Reducer attempts to delete a non-existing hook.");
|
||||
}
|
||||
hooks_.clear();
|
||||
}
|
||||
|
||||
} // namespace c10d
|
||||
|
@ -177,6 +177,8 @@ class TORCH_API Reducer {
|
||||
// current iteration, which means unused params set has not changed.
|
||||
bool ddp_graph_static();
|
||||
|
||||
void remove_autograd_hooks();
|
||||
|
||||
protected:
|
||||
// Forward declaration.
|
||||
struct Bucket;
|
||||
|
@ -2232,3 +2232,6 @@ class DistributedDataParallel(Module, Joinable):
|
||||
"unused parameters will not change during training loop while calling "
|
||||
"`_set_static_graph`."
|
||||
)
|
||||
|
||||
def _remove_autograd_hooks(self):
|
||||
self.reducer._remove_autograd_hooks()
|
||||
|
@ -9573,6 +9573,59 @@ class DistributedTest:
|
||||
loss_hook.backward()
|
||||
loss_no_hook.backward()
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
BACKEND not in DistTestCases.backend_feature["ddp"],
|
||||
f"The {BACKEND} backend does not support DistributedDataParallel",
|
||||
)
|
||||
def test_ddp_remove_autograd_hooks(self):
|
||||
|
||||
class SimulateError(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
return input
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
raise RuntimeError()
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self, device):
|
||||
super(MyModel, self).__init__()
|
||||
self.error = True
|
||||
self.fc1 = nn.Linear(10, 10).cuda(device)
|
||||
|
||||
def forward(self, inp):
|
||||
if self.error:
|
||||
return self.fc1(SimulateError.apply(inp))
|
||||
else:
|
||||
return self.fc1(inp)
|
||||
|
||||
|
||||
# Run with error to trigger backward pass that marks fc1 as being marked
|
||||
# ready. If we don't remove autograd hooks before running below it would
|
||||
# fail on the old autograd hook.
|
||||
model = MyModel(self.rank)
|
||||
input = torch.rand(10, 10, requires_grad=True).cuda(self.rank)
|
||||
model_ddp1 = torch.nn.parallel.DistributedDataParallel(
|
||||
model,
|
||||
device_ids=[self.rank],
|
||||
)
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
model_ddp1(input).sum().backward()
|
||||
|
||||
# Remove autograd hooks on old instance.
|
||||
model_ddp1._remove_autograd_hooks()
|
||||
|
||||
# Try another DDP instance without error now.
|
||||
model.error = False
|
||||
model_ddp2 = torch.nn.parallel.DistributedDataParallel(
|
||||
model,
|
||||
device_ids=[self.rank],
|
||||
)
|
||||
model_ddp2(input).sum().backward()
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
BACKEND not in DistTestCases.backend_feature["ddp"],
|
||||
|
Reference in New Issue
Block a user