mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 00:14:54 +08:00
[RFC] Add an API to remove autograd hooks from DDP (#96490)
Summary: When creating a new DDP instance for the same model when an old DDP instance existed, the autograd hooks from the old DDP instance might not be cleared. Also, relying on python gc to clear out old autograd hooks is fragile and may not work 100% of the time. As a result, in this PR I'm adding a way to explicitly remove these hooks from DDP Test Plan: Unit test added Pull Request resolved: https://github.com/pytorch/pytorch/pull/96490 Approved by: https://github.com/zhaojuanmao, https://github.com/rohan-varma
This commit is contained in:
committed by
PyTorch MergeBot
parent
fa82080016
commit
e20e5f5578
@ -9573,6 +9573,59 @@ class DistributedTest:
|
||||
loss_hook.backward()
|
||||
loss_no_hook.backward()
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
BACKEND not in DistTestCases.backend_feature["ddp"],
|
||||
f"The {BACKEND} backend does not support DistributedDataParallel",
|
||||
)
|
||||
def test_ddp_remove_autograd_hooks(self):
|
||||
|
||||
class SimulateError(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input):
|
||||
return input
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
raise RuntimeError()
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self, device):
|
||||
super(MyModel, self).__init__()
|
||||
self.error = True
|
||||
self.fc1 = nn.Linear(10, 10).cuda(device)
|
||||
|
||||
def forward(self, inp):
|
||||
if self.error:
|
||||
return self.fc1(SimulateError.apply(inp))
|
||||
else:
|
||||
return self.fc1(inp)
|
||||
|
||||
|
||||
# Run with error to trigger backward pass that marks fc1 as being marked
|
||||
# ready. If we don't remove autograd hooks before running below it would
|
||||
# fail on the old autograd hook.
|
||||
model = MyModel(self.rank)
|
||||
input = torch.rand(10, 10, requires_grad=True).cuda(self.rank)
|
||||
model_ddp1 = torch.nn.parallel.DistributedDataParallel(
|
||||
model,
|
||||
device_ids=[self.rank],
|
||||
)
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
model_ddp1(input).sum().backward()
|
||||
|
||||
# Remove autograd hooks on old instance.
|
||||
model_ddp1._remove_autograd_hooks()
|
||||
|
||||
# Try another DDP instance without error now.
|
||||
model.error = False
|
||||
model_ddp2 = torch.nn.parallel.DistributedDataParallel(
|
||||
model,
|
||||
device_ids=[self.rank],
|
||||
)
|
||||
model_ddp2(input).sum().backward()
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
BACKEND not in DistTestCases.backend_feature["ddp"],
|
||||
|
||||
Reference in New Issue
Block a user