Fully deprecate variadic inputs of checkpoint_sequential (#25985)

Summary:
Support for variadic inputs to `checkpoint_sequential` was deprecated in https://github.com/pytorch/pytorch/issues/21006. Such calls emitted a `DeprecationWarning` in PyTorch 1.2 and were scheduled to fail with a `TypeError` from PyTorch 1.3 onward. This patch removes the PyTorch 1.2 `DeprecationWarning` path, so passing anything other than a single input now raises a `TypeError`.
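For illustration, a minimal sketch of the caller-visible behavior after this change; the model, shapes, and variable names below are illustrative and not taken from the patch:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 20))
x = torch.randn(1, 100, requires_grad=True)

# Supported: a single input tensor for the whole sequence.
out = checkpoint_sequential(model, 2, x)
out.sum().backward()

# No longer supported: variadic inputs now fail outright instead of warning.
try:
    checkpoint_sequential(model, 2, x, x)
except TypeError as exc:
    print("rejected:", exc)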
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25985

Differential Revision: D18809875

Pulled By: albanD

fbshipit-source-id: e84dd8629c04979c4b2dc63e8ada94292e8cedd0
Heungsub Hans Lee
2019-12-05 09:21:47 -08:00
committed by Facebook Github Bot
parent 2607772959
commit fa251cfd97
2 changed files with 35 additions and 68 deletions


@@ -44,41 +44,37 @@ class TestCheckpoint(TestCase):
         model,
         module_lists_to_compare,
         num_chunks,
-        *inputs
+        input,
     ):
 
         # not checkpointed
-        if not isinstance(inputs, tuple):
-            inputs = (inputs,)
-        out = model(*inputs)
-        out_not_checkpointed = out.data.clone()
+        out = model(input)
+        out_not_checkpointed = out.detach().clone()
         model.zero_grad()
         out.sum().backward()
         grad_not_checkpointed = {
-            name: param.grad.data.clone()
+            name: param.grad.detach().clone()
            for name, param in model.named_parameters()
         }
-        input_grad_not_checkpointed = [i.grad.data.clone() for i in inputs]
+        input_grad_not_checkpointed = input.grad.detach().clone()
         for model_to_compare in module_lists_to_compare:
             # checkpointed model by passing list of modules
-            detached_inputs = [i.detach() for i in inputs]
-            for detached in detached_inputs:
-                detached.requires_grad = True
+            detached = input.detach()
+            detached.requires_grad = True
 
             # pass list of modules to checkpoint
-            out = checkpoint_sequential(model_to_compare, num_chunks, *detached_inputs)
-            out_checkpointed = out.data.clone()
+            out = checkpoint_sequential(model_to_compare, num_chunks, detached)
+            out_checkpointed = out.detach().clone()
             model.zero_grad()
             out.sum().backward()
             grad_checkpointed = {
-                name: param.grad.data.clone()
+                name: param.grad.detach().clone()
                 for name, param in model.named_parameters()
             }
-            input_grad_checkpointed = [d.grad.data.clone() for d in detached_inputs]
+            input_grad_checkpointed = detached.grad.detach().clone()
 
             # compare outputs as well as the gradients of input and parameters
             self.assertEqual(out_checkpointed, out_not_checkpointed)
-            for i, j in zip(input_grad_not_checkpointed, input_grad_checkpointed):
-                self.assertEqual(i, j)
+            self.assertEqual(input_grad_not_checkpointed, input_grad_checkpointed)
             for name in grad_checkpointed:
                 self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name])
@@ -152,12 +148,12 @@ class TestCheckpoint(TestCase):
             torch.randn(1, 100, requires_grad=True)
         )
 
-    def test_checkpoint_module_list_multiple_args(self):
+    def test_checkpoint_module_list(self):
         class ModuleListNet(nn.Module):
             def __init__(self):
                 super(ModuleListNet, self).__init__()
                 module_list = [
-                    nn.Bilinear(100, 60, 50),
+                    nn.Linear(100, 50),
                     nn.ReLU(),
                     nn.Linear(50, 20),
                     nn.ReLU(),
@@ -166,26 +162,19 @@ class TestCheckpoint(TestCase):
                 ]
                 self.module_list = nn.ModuleList(module_list)
 
-            def forward(self, *inputs):
+            def forward(self, input):
                 for layer in self.module_list:
-                    if isinstance(inputs, tuple):
-                        inputs = layer(*inputs)
-                    else:
-                        inputs = layer(inputs)
-                return inputs
+                    input = layer(input)
+                return input
 
         model = ModuleListNet()
 
-        # Compare uncheckpointed model with its checkpointed counterparts
-        # In addition to running checkpoint_sequential on the nn.ModuleList
-        # instance, we also run the function on the list of functions within
-        # the ModuleList.
+        # Compare uncheckpointed model with its checkpointed counterparts.
         self._check_checkpoint_sequential(
             model,
             [list(model.module_list.children()), model.module_list],
             2,
             torch.randn(1, 100, requires_grad=True),
-            torch.randn(1, 60, requires_grad=True)
         )
 
     def test_checkpoint_sequential_deprecated_multiple_args(self):
@@ -197,11 +186,8 @@ class TestCheckpoint(TestCase):
         a = torch.randn(1, 100, requires_grad=True)
         b = torch.randn(1, 100, requires_grad=True)
 
-        self.assertWarnsRegex(
-            lambda: checkpoint_sequential(model, 1, a, b),
-            'deprecated',
-            'checkpoint_sequential with multiple args should be deprecated',
-        )
+        with self.assertRaises(TypeError):
+            checkpoint_sequential(model, 1, a, b)
 
     def test_checkpoint_sequential_deprecated_no_args(self):
         class Noop(nn.Module):
@@ -210,11 +196,8 @@ class TestCheckpoint(TestCase):
 
         model = nn.Sequential(Noop())
 
-        self.assertWarnsRegex(
-            lambda: checkpoint_sequential(model, 1),
-            'deprecated',
-            'checkpoint_sequential with no args should be deprecated',
-        )
+        with self.assertRaises(TypeError):
+            checkpoint_sequential(model, 1)
 
     def test_checkpoint_rng_cpu(self):
         for _ in range(5):