[RFC] Add an API to remove autograd hooks from DDP (#96490)

Summary:
When a new DDP instance is created for a model that was previously wrapped by another DDP instance, the autograd hooks from the old DDP instance might not be cleared. Relying on the Python garbage collector to clear out the old autograd hooks is also fragile and may not work 100% of the time.

As a result, this PR adds a way to explicitly remove these hooks from DDP.
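
As a rough illustration of the intended usage, here is a minimal sketch (the linear layer, device placement, and process-group setup are placeholder assumptions for illustration; the unit test in the diff below is the authoritative example):

    import torch
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP

    # Assumes the default process group has already been initialized
    # (e.g. via torch.distributed.init_process_group).
    model = nn.Linear(10, 10).cuda()

    ddp_old = DDP(model)
    ddp_old(torch.rand(10, 10).cuda()).sum().backward()

    # Explicitly detach the autograd hooks registered by the old DDP instance
    # instead of relying on Python gc to clean them up eventually.
    ddp_old._remove_autograd_hooks()

    # The same model can now be safely wrapped in a fresh DDP instance.
    ddp_new = DDP(model)
    ddp_new(torch.rand(10, 10).cuda()).sum().backward()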

Test Plan:
Unit test added

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96490
Approved by: https://github.com/zhaojuanmao, https://github.com/rohan-varma
Author: Pritam Damania
Date: 2023-03-21 02:56:16 +00:00
Committed by: PyTorch MergeBot
Parent: fa82080016
Commit: e20e5f5578
6 changed files with 81 additions and 13 deletions

@@ -9573,6 +9573,59 @@ class DistributedTest:
        loss_hook.backward()
        loss_no_hook.backward()

    @skip_if_lt_x_gpu(2)
    @skip_but_pass_in_sandcastle_if(
        BACKEND not in DistTestCases.backend_feature["ddp"],
        f"The {BACKEND} backend does not support DistributedDataParallel",
    )
    def test_ddp_remove_autograd_hooks(self):
        class SimulateError(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                return input

            @staticmethod
            def backward(ctx, grad_output):
                raise RuntimeError()

        class MyModel(nn.Module):
            def __init__(self, device):
                super(MyModel, self).__init__()
                self.error = True
                self.fc1 = nn.Linear(10, 10).cuda(device)

            def forward(self, inp):
                if self.error:
                    return self.fc1(SimulateError.apply(inp))
                else:
                    return self.fc1(inp)
        # Run with error=True to trigger a backward pass that marks fc1 as
        # ready. If we don't remove the autograd hooks before running the
        # second DDP instance below, it would fail on the old autograd hook.
        model = MyModel(self.rank)
        input = torch.rand(10, 10, requires_grad=True).cuda(self.rank)
        model_ddp1 = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[self.rank],
        )
        with self.assertRaises(RuntimeError):
            model_ddp1(input).sum().backward()

        # Remove autograd hooks on old instance.
        model_ddp1._remove_autograd_hooks()

        # Try another DDP instance without error now.
        model.error = False
        model_ddp2 = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[self.rank],
        )
        model_ddp2(input).sum().backward()

    @skip_if_lt_x_gpu(2)
    @skip_but_pass_in_sandcastle_if(
        BACKEND not in DistTestCases.backend_feature["ddp"],