Mirror of https://github.com/pytorch/pytorch.git
Fully deprecate variadic inputs of checkpoint_sequential (#25985)
Summary: Support for variadic inputs of `checkpoint_sequential` was deprecated in https://github.com/pytorch/pytorch/issues/21006. This case emitted a `DeprecationWarning` in PyTorch 1.2, but since PyTorch 1.3 it should simply fail with a `TypeError`. This patch removes the `DeprecationWarning` fallback that was kept for PyTorch 1.2, so variadic calls now fail outright.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/25985
Differential Revision: D18809875
Pulled By: albanD
fbshipit-source-id: e84dd8629c04979c4b2dc63e8ada94292e8cedd0
Committed by: Facebook Github Bot
Parent: 2607772959
Commit: fa251cfd97
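For illustration, a minimal sketch of the behavior change described above (the layer sizes and the 2-segment split are arbitrary choices that mirror the updated tests in the diff below, not part of the commit itself):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 20))
a = torch.randn(1, 100, requires_grad=True)
b = torch.randn(1, 100, requires_grad=True)

# Still supported: a single input tensor, checkpointed in 2 segments.
out = checkpoint_sequential(model, 2, a)
out.sum().backward()

# Since PyTorch 1.3, extra positional inputs no longer emit a
# DeprecationWarning; the call simply fails.
try:
    checkpoint_sequential(model, 2, a, b)
except TypeError as err:
    print("variadic inputs now raise TypeError:", err)
```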
@@ -44,41 +44,37 @@ class TestCheckpoint(TestCase):
         model,
         module_lists_to_compare,
         num_chunks,
-        *inputs
+        input,
     ):
 
         # not checkpointed
-        if not isinstance(inputs, tuple):
-            inputs = (inputs,)
-        out = model(*inputs)
-        out_not_checkpointed = out.data.clone()
+        out = model(input)
+        out_not_checkpointed = out.detach().clone()
         model.zero_grad()
         out.sum().backward()
         grad_not_checkpointed = {
-            name: param.grad.data.clone()
+            name: param.grad.detach().clone()
             for name, param in model.named_parameters()
         }
-        input_grad_not_checkpointed = [i.grad.data.clone() for i in inputs]
+        input_grad_not_checkpointed = input.grad.detach().clone()
         for model_to_compare in module_lists_to_compare:
             # checkpointed model by passing list of modules
-            detached_inputs = [i.detach() for i in inputs]
-            for detached in detached_inputs:
-                detached.requires_grad = True
+            detached = input.detach()
+            detached.requires_grad = True
 
             # pass list of modules to checkpoint
-            out = checkpoint_sequential(model_to_compare, num_chunks, *detached_inputs)
-            out_checkpointed = out.data.clone()
+            out = checkpoint_sequential(model_to_compare, num_chunks, detached)
+            out_checkpointed = out.detach().clone()
             model.zero_grad()
             out.sum().backward()
             grad_checkpointed = {
-                name: param.grad.data.clone()
+                name: param.grad.detach().clone()
                 for name, param in model.named_parameters()
             }
-            input_grad_checkpointed = [d.grad.data.clone() for d in detached_inputs]
+            input_grad_checkpointed = detached.grad.detach().clone()
             # compare outputs as well as the gradients of input and parameters
             self.assertEqual(out_checkpointed, out_not_checkpointed)
-            for i, j in zip(input_grad_not_checkpointed, input_grad_checkpointed):
-                self.assertEqual(i, j)
+            self.assertEqual(input_grad_not_checkpointed, input_grad_checkpointed)
             for name in grad_checkpointed:
                 self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name])
 
@@ -152,12 +148,12 @@ class TestCheckpoint(TestCase):
             torch.randn(1, 100, requires_grad=True)
         )
 
-    def test_checkpoint_module_list_multiple_args(self):
+    def test_checkpoint_module_list(self):
         class ModuleListNet(nn.Module):
             def __init__(self):
                 super(ModuleListNet, self).__init__()
                 module_list = [
-                    nn.Bilinear(100, 60, 50),
+                    nn.Linear(100, 50),
                     nn.ReLU(),
                     nn.Linear(50, 20),
                     nn.ReLU(),
@@ -166,26 +162,19 @@ class TestCheckpoint(TestCase):
                 ]
                 self.module_list = nn.ModuleList(module_list)
 
-            def forward(self, *inputs):
+            def forward(self, input):
                 for layer in self.module_list:
-                    if isinstance(inputs, tuple):
-                        inputs = layer(*inputs)
-                    else:
-                        inputs = layer(inputs)
-                return inputs
+                    input = layer(input)
+                return input
 
         model = ModuleListNet()
 
-        # Compare uncheckpointed model with its checkpointed counterparts
-        # In addition to running checkpoint_sequential on the nn.ModuleList
-        # instance, we also run the function on the list of functions within
-        # the ModuleList.
+        # Compare uncheckpointed model with its checkpointed counterparts.
         self._check_checkpoint_sequential(
             model,
             [list(model.module_list.children()), model.module_list],
             2,
             torch.randn(1, 100, requires_grad=True),
-            torch.randn(1, 60, requires_grad=True)
         )
 
     def test_checkpoint_sequential_deprecated_multiple_args(self):
@@ -197,11 +186,8 @@ class TestCheckpoint(TestCase):
         a = torch.randn(1, 100, requires_grad=True)
         b = torch.randn(1, 100, requires_grad=True)
 
-        self.assertWarnsRegex(
-            lambda: checkpoint_sequential(model, 1, a, b),
-            'deprecated',
-            'checkpoint_sequential with multiple args should be deprecated',
-        )
+        with self.assertRaises(TypeError):
+            checkpoint_sequential(model, 1, a, b)
 
     def test_checkpoint_sequential_deprecated_no_args(self):
         class Noop(nn.Module):
@@ -210,11 +196,8 @@ class TestCheckpoint(TestCase):
 
         model = nn.Sequential(Noop())
 
-        self.assertWarnsRegex(
-            lambda: checkpoint_sequential(model, 1),
-            'deprecated',
-            'checkpoint_sequential with no args should be deprecated',
-        )
+        with self.assertRaises(TypeError):
+            checkpoint_sequential(model, 1)
 
     def test_checkpoint_rng_cpu(self):
         for _ in range(5):