[RFC] Add an API to remove autograd hooks from DDP (#96490)

Summary:
When creating a new DDP instance for a model that an older DDP instance previously wrapped, the autograd hooks registered by the old instance might not be cleared. Relying on Python GC to clear out these old autograd hooks is also fragile and may not work 100% of the time.

As a result, this PR adds a way to explicitly remove these hooks from DDP.
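
For illustration, a minimal usage sketch (mirroring the unit test added below); it assumes the default process group has already been initialized and that `rank` holds this process's local GPU index:

import torch
import torch.nn as nn

# Assumes torch.distributed.init_process_group(...) has been called and
# `rank` is this process's local GPU index (both are assumptions here).
model = nn.Linear(10, 10).cuda(rank)
ddp_old = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

# Explicitly detach the old instance's autograd hooks from the model's
# gradient accumulators instead of relying on Python GC.
ddp_old._remove_autograd_hooks()

# A fresh DDP instance can now wrap the same model without the stale hooks
# invoking methods on the old Reducer during backward.
ddp_new = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])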

Test Plan:
Unit test added

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96490
Approved by: https://github.com/zhaojuanmao, https://github.com/rohan-varma
Author: Pritam Damania
Date: 2023-03-21 02:56:16 +00:00
Committed by: PyTorch MergeBot
Parent: fa82080016
Commit: e20e5f5578
6 changed files with 81 additions and 13 deletions

View File

@@ -60,6 +60,7 @@ class Reducer:
     def _set_static_graph(self) -> None: ...
     def _run_comm_hook(self, bucket: GradBucket) -> Future: ...
     def set_logger(self, logger: Logger) -> None: ...
+    def _remove_autograd_hooks(self) -> None: ...
 
 class DDPLoggingData:
     strs_map: Dict[str, str]

View File

@@ -517,7 +517,11 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO
           const std::shared_ptr<::c10d::Logger> logger) {
             std::weak_ptr<::c10d::Logger> logger_weakref = logger;
             reducer.set_logger(logger_weakref);
-          });
+          })
+      .def(
+          "_remove_autograd_hooks",
+          [](::c10d::Reducer& reducer) { reducer.remove_autograd_hooks(); },
+          py::call_guard<py::gil_scoped_release>());
 
   shared_ptr_class_<::c10d::Logger>(module, "Logger")
       .def(

View File

@@ -259,18 +259,7 @@ Reducer::Reducer(
 // be specified by calling `register_builtin_comm_hook` from Python API.
 
 Reducer::~Reducer() noexcept(false) {
-  // Remove all hooks on variables registered by this Reducer. This is necessary
-  // to make DDP failure recoverable. Otherwise, multiple Reducer instances
-  // (from recoveries) will add their hooks to the original model, and those
-  // hooks will try to invoke methods on a deleted Reducer object.
-  for (auto& hook : hooks_) {
-    auto& key = hook.first;
-    auto& grad_accumulator = hook.second;
-    TORCH_INTERNAL_ASSERT(
-        grad_accumulator->del_post_hook(key),
-        "Reducer attempts to delete a non-existing hook.");
-  }
+  remove_autograd_hooks();
 }
 
 bool Reducer::dynamic_graph_find_unused() {
@@ -2240,4 +2229,20 @@ void verify_params_across_processes(
   }
 }
+
+void Reducer::remove_autograd_hooks() {
+  // Remove all hooks on variables registered by this Reducer. This is necessary
+  // to make DDP failure recoverable. Otherwise, multiple Reducer instances
+  // (from recoveries) will add their hooks to the original model, and those
+  // hooks will try to invoke methods on a deleted Reducer object.
+  for (auto& hook : hooks_) {
+    auto& key = hook.first;
+    auto& grad_accumulator = hook.second;
+    TORCH_INTERNAL_ASSERT(
+        grad_accumulator->del_post_hook(key),
+        "Reducer attempts to delete a non-existing hook.");
+  }
+  hooks_.clear();
+}
+
 } // namespace c10d

View File

@@ -177,6 +177,8 @@ class TORCH_API Reducer {
   // current iteration, which means unused params set has not changed.
   bool ddp_graph_static();
 
+  void remove_autograd_hooks();
+
  protected:
   // Forward declaration.
   struct Bucket;

View File

@@ -2232,3 +2232,6 @@ class DistributedDataParallel(Module, Joinable):
                 "unused parameters will not change during training loop while calling "
                 "`_set_static_graph`."
             )
+
+    def _remove_autograd_hooks(self):
+        self.reducer._remove_autograd_hooks()

View File

@@ -9573,6 +9573,59 @@ class DistributedTest:
             loss_hook.backward()
             loss_no_hook.backward()
 
+        @skip_if_lt_x_gpu(2)
+        @skip_but_pass_in_sandcastle_if(
+            BACKEND not in DistTestCases.backend_feature["ddp"],
+            f"The {BACKEND} backend does not support DistributedDataParallel",
+        )
+        def test_ddp_remove_autograd_hooks(self):
+            class SimulateError(torch.autograd.Function):
+                @staticmethod
+                def forward(ctx, input):
+                    return input
+
+                @staticmethod
+                def backward(ctx, grad_output):
+                    raise RuntimeError()
+
+            class MyModel(nn.Module):
+                def __init__(self, device):
+                    super(MyModel, self).__init__()
+                    self.error = True
+                    self.fc1 = nn.Linear(10, 10).cuda(device)
+
+                def forward(self, inp):
+                    if self.error:
+                        return self.fc1(SimulateError.apply(inp))
+                    else:
+                        return self.fc1(inp)
+
+            # Run with an error to trigger a backward pass that marks fc1 as ready.
+            # If we don't remove the autograd hooks before the run below, it would
+            # fail on the old autograd hook.
+            model = MyModel(self.rank)
+            input = torch.rand(10, 10, requires_grad=True).cuda(self.rank)
+            model_ddp1 = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+            )
+
+            with self.assertRaises(RuntimeError):
+                model_ddp1(input).sum().backward()
+
+            # Remove autograd hooks on old instance.
+            model_ddp1._remove_autograd_hooks()
+
+            # Try another DDP instance without error now.
+            model.error = False
+            model_ddp2 = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[self.rank],
+            )
+            model_ddp2(input).sum().backward()
+
         @skip_if_lt_x_gpu(2)
         @skip_but_pass_in_sandcastle_if(
             BACKEND not in DistTestCases.backend_feature["ddp"],