From bef61225c39bb2df67f5db54c12dadd36ae272ab Mon Sep 17 00:00:00 2001 From: lezcano Date: Wed, 8 Feb 2023 09:36:20 +0000 Subject: [PATCH] [decompositions] add decomposition for RNN with packed sequence (#91281) Pull Request resolved: https://github.com/pytorch/pytorch/pull/91281 Approved by: https://github.com/zou3519 --- test/functorch/test_aotdispatch.py | 9 +- test/test_decomp.py | 1 - test/test_expanded_weights.py | 54 ++++++- torch/_decomp/decompositions.py | 136 +++++++++++++++++- .../expanded_weights_impl.py | 54 +++++-- .../expanded_weights_utils.py | 12 +- torch/testing/_internal/common_modules.py | 25 +++- 7 files changed, 270 insertions(+), 21 deletions(-) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 2cb68a0a3a58..994aa9e7da73 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -23,6 +23,7 @@ import unittest import warnings import itertools from functools import partial +from torch.nn.utils.rnn import PackedSequence from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed from torch.testing._internal.common_modules import module_db, modules @@ -2517,11 +2518,17 @@ def _test_aot_autograd_module_helper(self, device, dtype, training, module_info) # Lazy modules need to see an input first to initialize params. args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs + flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + + # PackedSequence is only used for RNNs. It might be possible to fake-ify if they're pytrees but + # torchdynamo already doesn't support RNNs + if any(tuple(isinstance(flat_arg, PackedSequence) for flat_arg in flat_args)): + continue + if issubclass(module_info.module_cls, torch.nn.modules.lazy.LazyModuleMixin): with torch.no_grad(): m(*args, **kwargs) - flat_args, args_spec = pytree.tree_flatten((args, kwargs)) sentinel_val = -42 is_tensor_spec = [sentinel_val if isinstance(arg, torch.Tensor) else arg for arg in flat_args] diff --git a/test/test_decomp.py b/test/test_decomp.py index 966817985d8f..776c1b328fdc 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -443,7 +443,6 @@ class TestDecomp(TestCase): # they're checking aten decomps at the torch_dispatch level self.assertEqual(decomp_out, non_decomp_out) - class DecompCrossRefMode(TorchDispatchMode): def __init__(self, test_case, saved_precision, saved_rel_tol, dtype, run_all): self.test_case = test_case diff --git a/test/test_expanded_weights.py b/test/test_expanded_weights.py index bb982dc4fc29..d0973714509e 100644 --- a/test/test_expanded_weights.py +++ b/test/test_expanded_weights.py @@ -80,13 +80,14 @@ class TestExpandedWeightHelperFunction(TestCase): def test_set_grad_sample_if_exists(self, device): def test_fn(a): - return True + return grad_sample orig_weight = torch.randn(4, device=device, requires_grad=True) expanded_weight = ExpandedWeight(orig_weight, 3, loss_reduction="sum") + grad_sample = torch.randn(3) set_grad_sample_if_exists(expanded_weight, test_fn) self.assertTrue(hasattr(orig_weight, 'grad_sample')) - self.assertTrue(orig_weight.grad_sample) + self.assertEqual(orig_weight.grad_sample, grad_sample) basic_tensor = torch.randn(4, device=device) set_grad_sample_if_exists(basic_tensor, test_fn) @@ -474,6 +475,43 @@ class TestExpandedWeightModule(TestCase): expected_grads = tuple(expected_grad for expected_grad in expected_grads if expected_grad is not 
None) assert [self.assertEqual(actual, 2 * expected) for (actual, expected) in zip(actual_grads, expected_grads)] + def _do_test_rnn_packed_sequence(self, module, input, args=None, kwargs=None): + args = args if args is not None else () + kwargs = kwargs if kwargs is not None else {} + + batch_size = max(tuple(input.batch_sizes)).item() + + with freeze_rng_state(): + # get per sample grads with ExpandedWeights context manager + actual_res = call_for_per_sample_grads(module, + batch_size=batch_size, + loss_reduction="sum")(input, *args, **kwargs).data.sum() + actual_res.backward() + actual_grads = [] + for param in module.parameters(): + self.assertEqual(param.grad_sample.shape[0], batch_size) + actual_grads.append(param.grad_sample) + del param.grad_sample + + input.data.grad = torch.zeros_like(input.data) + + # compute the per sample grads with a for loop + expected_res = torch.zeros_like(actual_res) + expected_grads = [] + padded_input, seq_sizes = torch.nn.utils.rnn.pad_packed_sequence(input, batch_first=True) + for i in range(len(seq_sizes)): + input_slice = padded_input[i].narrow(0, 0, seq_sizes[i]) + diff_params = module.parameters() + batch_dim = 0 if module.m.batch_first else 1 + res = module(input_slice.unsqueeze(batch_dim), *args, **kwargs).sum() + expected_res += res + out_grads = torch.autograd.grad(res, diff_params, torch.ones_like(res), allow_unused=True) + expected_grads.append(out_grads) + + expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)] + self.assertEqual(actual_res, expected_res) + [self.assertEqual(actual, expected) for (actual, expected) in zip(actual_grads, expected_grads)] + @modules(filter(lambda m_info: m_info.module_cls in (torch.nn.RNN, torch.nn.LSTM), module_db)) def test_module(self, device, dtype, module_info, training): class RNNWrapper(torch.nn.Module): @@ -494,7 +532,7 @@ class TestExpandedWeightModule(TestCase): module_cls = module_info.module_cls module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, - requires_grad=True, training=training) + requires_grad=True, training=training, with_packed_sequence=True) for module_input in module_inputs: if module_input.forward_input is None: continue @@ -506,8 +544,9 @@ class TestExpandedWeightModule(TestCase): args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs # if the RNN tests use unbatched inputs--batch the inputs - input = args[0].detach() - if input.dim() == 2: + input = args[0] + if isinstance(input, torch.Tensor) and input.dim() == 2: + input = input.detach() new_input_shape = [1] * (len(input.shape) + 1) if batch_first: new_input_shape[0] = 2 @@ -522,7 +561,10 @@ class TestExpandedWeightModule(TestCase): args = list(args) args[1] = h - self._do_test(m, input, args[1:], kwargs, batch_first=batch_first) + if isinstance(input, torch.nn.utils.rnn.PackedSequence): + self._do_test_rnn_packed_sequence(m, input, args[1:], kwargs) + else: + self._do_test(m, input, args[1:], kwargs, batch_first=batch_first) def test_per_sample_api_failing(self): module = nn.Linear(10, 10) diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index d88db1a7c63b..aefa7ff2156b 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -2149,6 +2149,73 @@ def params_hiddens(params, hiddens, i, bidirectional): return cur_params, cur_hidden, bidir_params, bidir_hidden +def update_hidden_for_packed(cur_hidden, last_batch_size, batch_size, hiddens): + assert last_batch_size > batch_size + 
hiddens.append(cur_hidden.narrow(0, batch_size, last_batch_size - batch_size)) + return cur_hidden.narrow(0, 0, batch_size) + + +def update_hidden_for_packed_reverse( + cur_hidden, last_batch_size, batch_size, inp_hidden +): + if last_batch_size == batch_size: + return cur_hidden + assert last_batch_size < batch_size + return torch.concat( + ( + cur_hidden, + inp_hidden.narrow(0, last_batch_size, batch_size - last_batch_size), + ) + ) + + +def one_layer_rnn_data( + inp, hidden, params, has_biases, nonlinearity, batch_sizes, reverse=False +): + ih_weight = params[0] + hh_weight = params[1] + ih_bias = params[2] if has_biases else None + hh_bias = params[3] if has_biases else None + + step_output = [] + hiddens: List["torch.Tensor"] = [] + + last_batch_size = batch_sizes[-1] if reverse else batch_sizes[0] + cur_hidden = hidden.narrow(0, 0, last_batch_size) + split_inp = torch.split(inp, list(batch_sizes)) + if reverse: + split_inp = split_inp[::-1] + for inp in split_inp: + i = inp.shape[0] + + if last_batch_size == i: + pass # don't update cur_hidden + # this will only happen when reverse=False, since batch sizes are sorted largest -> smallest + elif reverse: + cur_hidden = update_hidden_for_packed_reverse( + cur_hidden, last_batch_size, i, hidden + ) + else: + cur_hidden = update_hidden_for_packed( + cur_hidden, last_batch_size, i, hiddens + ) + + inp = F.linear(inp, ih_weight, ih_bias) + cur_hidden = nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + inp) + last_batch_size = i + step_output.append(cur_hidden) + + if reverse: + step_output.reverse() + else: + hiddens.append(cur_hidden) + hiddens.reverse() + + out = torch.cat(step_output, 0) + hidden_out = torch.cat(hiddens, 0) if not reverse else cur_hidden + return out, hidden_out + + def one_layer_rnn(inp, hidden, params, has_biases, nonlinearity, reverse=False): ih_weight = params[0] hh_weight = params[1] @@ -2163,6 +2230,9 @@ def one_layer_rnn(inp, hidden, params, has_biases, nonlinearity, reverse=False): cur_hidden = nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + inp) step_output.append(cur_hidden) + if reverse: + step_output.reverse() + out = torch.cat(step_output, 0) return out, cur_hidden.squeeze(0) @@ -2195,7 +2265,6 @@ def _rnn_helper( bwd_inp, bwd_hidden = layer_fn( input, bidir_hidden, bidir_params, has_biases, reverse=True ) - bwd_inp = bwd_inp.flip(0) final_hiddens.append(bwd_hidden) if bidirectional: @@ -2272,6 +2341,68 @@ def rnn_relu_input( return out, torch.stack(final_hiddens, 0) +@register_decomposition(aten.rnn_relu.data) +@aten.rnn_relu.data.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.rnn_relu.data.py_impl(DispatchKey.Autograd) +def rnn_relu_data( + data, + batch_sizes, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, +): + hidden = hx.unbind(0) + params = gather_params(params, has_biases, False) + out, final_hiddens = _rnn_helper( + data, + hidden, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + False, + partial(one_layer_rnn_data, batch_sizes=batch_sizes, nonlinearity=torch.relu), + ) + return out, torch.stack(final_hiddens, 0) + + +@register_decomposition(aten.rnn_tanh.data) +@aten.rnn_tanh.data.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.rnn_tanh.data.py_impl(DispatchKey.Autograd) +def rnn_tanh_data( + data, + batch_sizes, + hx, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, +): + hidden = hx.unbind(0) + params = gather_params(params, has_biases, False) + out, final_hiddens = _rnn_helper( + 
data, + hidden, + params, + has_biases, + num_layers, + dropout, + train, + bidirectional, + False, + partial(one_layer_rnn_data, batch_sizes=batch_sizes, nonlinearity=torch.tanh), + ) + return out, torch.stack(final_hiddens, 0) + + def one_layer_lstm(inp, hidden, params, has_biases, reverse=False): ih_weight = params[0] hh_weight = params[1] @@ -2302,6 +2433,9 @@ def one_layer_lstm(inp, hidden, params, has_biases, reverse=False): hx = hy cx = cy + if reverse: + step_output.reverse() + out = torch.cat(step_output, 0) return out, (hx.squeeze(1), cx.squeeze(1)) diff --git a/torch/nn/utils/_expanded_weights/expanded_weights_impl.py b/torch/nn/utils/_expanded_weights/expanded_weights_impl.py index 0702cdbc3390..1997fab3fb10 100644 --- a/torch/nn/utils/_expanded_weights/expanded_weights_impl.py +++ b/torch/nn/utils/_expanded_weights/expanded_weights_impl.py @@ -11,24 +11,54 @@ from torch.utils._pytree import tree_map_only HANDLED_FUNCTIONS: Dict[Callable, torch.autograd.Function] = {} -# __torch_function__ runs before the pydispatcher so we need to use the same +aten = torch._ops.ops.aten +# __torch_function__ runs before the pydispatcher so we need to manually use the same # decompositions indexed by their torch equivalent expanded_weights_rnn_decomps = { # func: (input_decomp, data_decomp) - torch.rnn_relu: (decomposition_table[torch._ops.ops.aten.rnn_relu.input], None), - torch.rnn_tanh: (decomposition_table[torch._ops.ops.aten.rnn_tanh.input], None), - torch.lstm: (decomposition_table[torch._ops.ops.aten.lstm.input], None), + torch.rnn_relu: (decomposition_table[aten.rnn_relu.input], decomposition_table[aten.rnn_relu.data]), + torch.rnn_tanh: (decomposition_table[aten.rnn_tanh.input], decomposition_table[aten.rnn_tanh.data]), + torch.lstm: (decomposition_table[aten.lstm.input], None), } +# all of the RNN decomps run linear with the batch dimension second, even if batch_first was set @contextmanager def batch_second(args, kwargs): - tree_map_only(ExpandedWeight, functools.partial(ExpandedWeight.set_batch_first, is_batch_first=False), args) - tree_map_only(ExpandedWeight, functools.partial(ExpandedWeight.set_batch_first, is_batch_first=False), kwargs) + def set_batch_second(ew): + ew.set_batch_first(False) + + def reset_batch_first(ew): + ew.set_batch_first(True) + + tree_map_only(ExpandedWeight, set_batch_second, args) + tree_map_only(ExpandedWeight, set_batch_second, kwargs) try: yield finally: - tree_map_only(ExpandedWeight, functools.partial(ExpandedWeight.set_batch_first, is_batch_first=True), args) - tree_map_only(ExpandedWeight, functools.partial(ExpandedWeight.set_batch_first, is_batch_first=True), kwargs) + tree_map_only(ExpandedWeight, reset_batch_first, args) + tree_map_only(ExpandedWeight, reset_batch_first, kwargs) + +# to support packed sequences, we need to allow for smaller batches. 
Expanded weights represents the largest batch +@contextmanager +def allow_smaller_batches(args, kwargs): + def allow(ew): + ew.set_allow_smaller_batches(True) + + def reset(ew): + ew.set_allow_smaller_batches(False) + + tree_map_only(ExpandedWeight, allow, args) + tree_map_only(ExpandedWeight, allow, kwargs) + try: + yield + finally: + tree_map_only(ExpandedWeight, reset, args) + tree_map_only(ExpandedWeight, reset, kwargs) + +@contextmanager +def setup_rnn(use_input_variant, args, kwargs): + with batch_second(args, kwargs) if use_input_variant else allow_smaller_batches(args, kwargs): + yield def implements_per_sample_grads(torch_function): @@ -54,6 +84,7 @@ class ExpandedWeight(torch.Tensor): def __init__(self, orig_weight, batch_size, loss_reduction): self.batch_size = batch_size self.batch_first = True + self.allow_smaller_batches = False self.orig_weight = orig_weight self.loss_reduction = loss_reduction @@ -74,11 +105,11 @@ class ExpandedWeight(torch.Tensor): if func in expanded_weights_rnn_decomps: # in aten, choosing the input or data variants is done by parsing logic. This mimics some of that decomp_opts = expanded_weights_rnn_decomps[func] - use_input_variant = not isinstance(args[1], list) # data variant uses a list here + use_input_variant = isinstance(args[2], list) # data variant uses a list here decomp = decomp_opts[0] if use_input_variant else decomp_opts[1] if decomp is not None: - with batch_second(args, kwargs): + with setup_rnn(use_input_variant, args, kwargs): return decomp(*args, **kwargs) if func == torch._cudnn_rnn_flatten_weight: # since we aren't using the fused cuda kernels for RNNs, don't do this @@ -115,5 +146,8 @@ class ExpandedWeight(torch.Tensor): def get_device(self): return self.orig_weight.get_device() + def set_allow_smaller_batches(self, is_allow_smaller_batches): + self.allow_smaller_batches = is_allow_smaller_batches + def set_batch_first(self, is_batch_first=True): self.batch_first = is_batch_first diff --git a/torch/nn/utils/_expanded_weights/expanded_weights_utils.py b/torch/nn/utils/_expanded_weights/expanded_weights_utils.py index 0f429bbdb222..b3c91481c18c 100644 --- a/torch/nn/utils/_expanded_weights/expanded_weights_utils.py +++ b/torch/nn/utils/_expanded_weights/expanded_weights_utils.py @@ -62,7 +62,8 @@ def _check_and_unexpand_args(func, expanded_args, expanded_kwargs): if not isinstance(arg, ExpandedWeight): continue batch_size = input.shape[0] if arg.batch_first else input.shape[1] - if arg.batch_size != batch_size: + if (arg.allow_smaller_batches and batch_size > arg.batch_size) or \ + (not arg.allow_smaller_batches and arg.batch_size != batch_size): raise RuntimeError("Expected ExpandedWeights to have batch size matching input but got " f"input batch size of {batch_size} with ExpandedWeight of batch size {arg.batch_size}") @@ -90,6 +91,15 @@ def set_grad_sample_if_exists(maybe_expanded_weight, per_sample_grad_fn): unpacked = unpack_expanded_weight_or_tensor(maybe_expanded_weight) if isinstance(maybe_expanded_weight, ExpandedWeight): grad_sample_contribution = maybe_scale_by_batch_size(per_sample_grad_fn(unpacked), maybe_expanded_weight) + + if maybe_expanded_weight.batch_size > grad_sample_contribution.shape[0]: + # this only passes the other checks if the arg allows smaller batch sizes + intermediate = torch.zeros(maybe_expanded_weight.batch_size, *grad_sample_contribution.shape[1:], + dtype=grad_sample_contribution.dtype, + device=grad_sample_contribution.device) + intermediate[:grad_sample_contribution.shape[0]] = 
grad_sample_contribution + grad_sample_contribution = intermediate + if hasattr(unpacked, "grad_sample") and unpacked.grad_sample is not None: unpacked.grad_sample = unpacked.grad_sample + grad_sample_contribution else: diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index 490bef08979e..12c54668d848 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -6,6 +6,7 @@ from functools import wraps, partial from itertools import chain, product import itertools import torch.nn.functional as F +from torch.nn.utils.rnn import pack_padded_sequence from torch.testing import make_tensor from torch.testing._internal.common_cuda import TEST_CUDNN from torch.testing._internal.common_dtype import floating_types, floating_and_complex_types_and @@ -947,8 +948,15 @@ def module_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_grad, t return samples +def make_packed_sequence(inp, batch_sizes): + required_grad = inp.requires_grad + inp.requires_grad_(False) # user won't have access to inp so won't be able to get its grads + seq = pack_padded_sequence(inp, batch_sizes) + seq.data.requires_grad_(required_grad) + return seq -def module_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, training, **kwargs): + +def module_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, training, with_packed_sequence=False, **kwargs): # Currently all samples below are for validating the no-batch-dim support. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) is_rnn = kwargs['is_rnn'] @@ -991,6 +999,21 @@ def module_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, tr reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f), ) ) + if with_packed_sequence: + samples.append( + ModuleInput( + constructor_input=FunctionInput(**cons_args), + forward_input=FunctionInput(make_packed_sequence(make_input((5, 2, 2)), torch.tensor([5, 3]))), + reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f), + ) + ) + samples.append( + ModuleInput( + constructor_input=FunctionInput(**cons_args), + forward_input=FunctionInput(make_packed_sequence(make_input((5, 5, 2)), torch.tensor([5, 3, 3, 2, 2]))), + reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f), + ) + ) return samples
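
A minimal usage sketch of the path this patch enables: per-sample gradients for an RNN fed a PackedSequence. It mirrors the new _do_test_rnn_packed_sequence helper and the packed-sequence ModuleInputs added above; call_for_per_sample_grads is the private helper exercised in test_expanded_weights.py, so the import path, the keyword arguments, and the requirement that batch_size equal the largest batch in the PackedSequence are taken from those tests rather than from a public API.

    import torch
    import torch.nn as nn
    from torch.nn.utils.rnn import pack_padded_sequence
    # Private helper, imported the same way the tests do.
    from torch.nn.utils._per_sample_grad import call_for_per_sample_grads

    rnn = nn.RNN(input_size=2, hidden_size=3)

    # Padded input of shape (seq_len=5, batch=2, input_size=2); lengths sorted descending.
    padded = torch.randn(5, 2, 2)
    packed = pack_padded_sequence(padded, torch.tensor([5, 3]))

    # batch_size must be the largest batch in the PackedSequence (here 2).
    output, _ = call_for_per_sample_grads(rnn, batch_size=2, loss_reduction="sum")(packed)
    output.data.sum().backward()

    for name, param in rnn.named_parameters():
        # Each parameter now carries .grad_sample with one gradient per sample.
        print(name, tuple(param.grad_sample.shape))  # leading dim == 2

This is what _do_test_rnn_packed_sequence verifies above: with loss_reduction="sum", each grad_sample[i] matches the gradient obtained by padding the sequence, slicing out sample i, and running it through the module on its own.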