From fa251cfd97d91320ca3670184d97737b5cde9d88 Mon Sep 17 00:00:00 2001
From: Heungsub Hans Lee
Date: Thu, 5 Dec 2019 09:21:47 -0800
Subject: [PATCH] Fully deprecate variadic inputs of checkpoint_sequential (#25985)

Summary:
Support for variadic inputs to `checkpoint_sequential` was deprecated in
https://github.com/pytorch/pytorch/issues/21006. PyTorch 1.2 only warned about
this case with a `DeprecationWarning`; in later releases it should simply fail
with a `TypeError`. This patch removes the `DeprecationWarning` path, so
`checkpoint_sequential` now accepts exactly one input Tensor and raises
`TypeError` when called with no input or with multiple inputs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/25985

Differential Revision: D18809875

Pulled By: albanD

fbshipit-source-id: e84dd8629c04979c4b2dc63e8ada94292e8cedd0
---
 test/test_utils.py        | 61 ++++++++++++++-------------------------
 torch/utils/checkpoint.py | 42 +++++++++------------------
 2 files changed, 35 insertions(+), 68 deletions(-)

diff --git a/test/test_utils.py b/test/test_utils.py
index 104ab28dd355..063a0ece0b13 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -44,41 +44,37 @@ class TestCheckpoint(TestCase):
         model,
         module_lists_to_compare,
         num_chunks,
-        *inputs
+        input,
     ):
         # not checkpointed
-        if not isinstance(inputs, tuple):
-            inputs = (inputs,)
-        out = model(*inputs)
-        out_not_checkpointed = out.data.clone()
+        out = model(input)
+        out_not_checkpointed = out.detach().clone()
         model.zero_grad()
         out.sum().backward()
         grad_not_checkpointed = {
-            name: param.grad.data.clone()
+            name: param.grad.detach().clone()
             for name, param in model.named_parameters()
         }
-        input_grad_not_checkpointed = [i.grad.data.clone() for i in inputs]
+        input_grad_not_checkpointed = input.grad.detach().clone()
         for model_to_compare in module_lists_to_compare:
             # checkpointed model by passing list of modules
-            detached_inputs = [i.detach() for i in inputs]
-            for detached in detached_inputs:
-                detached.requires_grad = True
+            detached = input.detach()
+            detached.requires_grad = True
 
             # pass list of modules to checkpoint
-            out = checkpoint_sequential(model_to_compare, num_chunks, *detached_inputs)
+            out = checkpoint_sequential(model_to_compare, num_chunks, detached)
-            out_checkpointed = out.data.clone()
+            out_checkpointed = out.detach().clone()
 
             model.zero_grad()
             out.sum().backward()
             grad_checkpointed = {
-                name: param.grad.data.clone()
+                name: param.grad.detach().clone()
                 for name, param in model.named_parameters()
             }
-            input_grad_checkpointed = [d.grad.data.clone() for d in detached_inputs]
+            input_grad_checkpointed = detached.grad.detach().clone()
             # compare outputs as well as the gradients of input and parameters
             self.assertEqual(out_checkpointed, out_not_checkpointed)
-            for i, j in zip(input_grad_not_checkpointed, input_grad_checkpointed):
-                self.assertEqual(i, j)
+            self.assertEqual(input_grad_not_checkpointed, input_grad_checkpointed)
             for name in grad_checkpointed:
                 self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name])
 
@@ -152,12 +148,12 @@ class TestCheckpoint(TestCase):
             torch.randn(1, 100, requires_grad=True)
         )
 
-    def test_checkpoint_module_list_multiple_args(self):
+    def test_checkpoint_module_list(self):
         class ModuleListNet(nn.Module):
             def __init__(self):
                 super(ModuleListNet, self).__init__()
                 module_list = [
-                    nn.Bilinear(100, 60, 50),
+                    nn.Linear(100, 50),
                     nn.ReLU(),
                     nn.Linear(50, 20),
                     nn.ReLU(),
@@ -166,26 +162,19 @@ class TestCheckpoint(TestCase):
                 ]
                 self.module_list = nn.ModuleList(module_list)
 
-            def forward(self, *inputs):
+            def forward(self, input):
                 for layer in self.module_list:
-                    if isinstance(inputs, tuple):
-                        inputs = layer(*inputs)
-                    else:
-                        inputs = layer(inputs)
-                return inputs
+                    input = layer(input)
+                return input
 
         model = ModuleListNet()
 
-        # Compare uncheckpointed model with its checkpointed counterparts
-        # In addition to running checkpoint_sequential on the nn.ModuleList
-        # instance, we also run the function on the list of functions within
-        # the ModuleList.
+        # Compare uncheckpointed model with its checkpointed counterparts.
         self._check_checkpoint_sequential(
             model,
             [list(model.module_list.children()), model.module_list],
             2,
             torch.randn(1, 100, requires_grad=True),
-            torch.randn(1, 60, requires_grad=True)
         )
 
     def test_checkpoint_sequential_deprecated_multiple_args(self):
@@ -197,11 +186,8 @@ class TestCheckpoint(TestCase):
         a = torch.randn(1, 100, requires_grad=True)
         b = torch.randn(1, 100, requires_grad=True)
 
-        self.assertWarnsRegex(
-            lambda: checkpoint_sequential(model, 1, a, b),
-            'deprecated',
-            'checkpoint_sequential with multiple args should be deprecated',
-        )
+        with self.assertRaises(TypeError):
+            checkpoint_sequential(model, 1, a, b)
 
     def test_checkpoint_sequential_deprecated_no_args(self):
         class Noop(nn.Module):
@@ -210,11 +196,8 @@ class TestCheckpoint(TestCase):
 
         model = nn.Sequential(Noop())
 
-        self.assertWarnsRegex(
-            lambda: checkpoint_sequential(model, 1),
-            'deprecated',
-            'checkpoint_sequential with no args should be deprecated',
-        )
+        with self.assertRaises(TypeError):
+            checkpoint_sequential(model, 1)
 
     def test_checkpoint_rng_cpu(self):
         for _ in range(5):
diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py
index a6e3add335b2..bfedab7eaaf5 100644
--- a/torch/utils/checkpoint.py
+++ b/torch/utils/checkpoint.py
@@ -155,10 +155,7 @@ def checkpoint(function, *args, **kwargs):
     return CheckpointFunction.apply(function, preserve, *args)
 
 
-# TODO(sublee): When releasing PyTorch 1.3,
-# fix the function signature to not accept variadic arguments.
-# See also: https://github.com/pytorch/pytorch/issues/19260
-def checkpoint_sequential(functions, segments, *inputs, **kwargs):
+def checkpoint_sequential(functions, segments, input, **kwargs):
     r"""A helper function for checkpointing sequential models.
 
     Sequential models execute a list of modules/functions in order
@@ -179,11 +176,15 @@ def checkpoint_sequential(functions, segments, *inputs, **kwargs):
         grads are needed for model inputs, otherwise the checkpointed part of the
         model won't have gradients.
 
+    .. warning::
+        Since PyTorch 1.4, it accepts only one Tensor as the input and as each
+        intermediate output, just like :class:`torch.nn.Sequential`.
+
     Args:
         functions: A :class:`torch.nn.Sequential` or the list of modules or
             functions (comprising the model) to run sequentially.
         segments: Number of chunks to create in the model
-        inputs: tuple of Tensors that are inputs to :attr:`functions`
+        input: A Tensor that is input to :attr:`functions`
         preserve_rng_state(bool, optional, default=True):  Omit stashing and restoring
             the RNG state during each checkpoint.
 
@@ -194,31 +195,16 @@ def checkpoint_sequential(functions, segments, *inputs, **kwargs):
         >>> model = nn.Sequential(...)
         >>> input_var = checkpoint_sequential(model, chunks, input_var)
     """
-    # Hack to mix *args with **kwargs in a python 2.7-compliant way
+    # Hack for keyword-only parameter in a python 2.7-compliant way
     preserve = kwargs.pop('preserve_rng_state', True)
     if kwargs:
         raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
 
-    # To accept variadic arguments is not consistent with nn.Sequential.
-    # This interface will be changed at PyTorch 1.3.
-    # See also: https://github.com/pytorch/pytorch/issues/19260
-    if not inputs:
-        warnings.warn('Giving no input to checkpoint_sequential has been deprecated, '
-                      'a TypeError will be raised after PyTorch 1.3',
-                      DeprecationWarning)
-    elif len(inputs) > 1:
-        warnings.warn('multiple inputs to checkpoint_sequential has been deprecated, '
-                      'a TypeError will be raised after PyTorch 1.3',
-                      DeprecationWarning)
-
     def run_function(start, end, functions):
-        def forward(*inputs):
+        def forward(input):
             for j in range(start, end + 1):
-                if isinstance(inputs, tuple):
-                    inputs = functions[j](*inputs)
-                else:
-                    inputs = functions[j](inputs)
-            return inputs
+                input = functions[j](input)
+            return input
         return forward
 
     if isinstance(functions, torch.nn.Sequential):
@@ -229,8 +215,6 @@ def checkpoint_sequential(functions, segments, *inputs, **kwargs):
     end = -1
     for start in range(0, segment_size * (segments - 1), segment_size):
         end = start + segment_size - 1
-        inputs = checkpoint(run_function(start, end, functions), *inputs,
-                            preserve_rng_state=preserve)
-        if not isinstance(inputs, tuple):
-            inputs = (inputs,)
-    return run_function(end + 1, len(functions) - 1, functions)(*inputs)
+        input = checkpoint(run_function(start, end, functions), input,
+                           preserve_rng_state=preserve)
+    return run_function(end + 1, len(functions) - 1, functions)(input)
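
Usage sketch (illustrative, not part of the patch): after this change,
`checkpoint_sequential` mirrors `nn.Sequential` and takes exactly one input
Tensor; the model, shapes, and tensor names below are made up for the example.

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint_sequential

    model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 20))
    x = torch.randn(1, 100, requires_grad=True)

    # New calling convention: one Tensor in, one Tensor out,
    # with the model split into two checkpointed segments.
    out = checkpoint_sequential(model, 2, x)
    out.sum().backward()

    # The previously deprecated variadic forms now raise TypeError:
    #   checkpoint_sequential(model, 2, a, b)   # multiple inputs
    #   checkpoint_sequential(model, 2)         # no input

As before, keyword arguments other than `preserve_rng_state` are still
rejected with a `ValueError`.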