[decompositions] add decomposition for RNN with packed sequence (#91281)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91281
Approved by: https://github.com/zou3519
lezcano authored on 2023-02-08 09:36:20 +00:00; committed by PyTorch MergeBot
parent e5f6e1f660
commit bef61225c3
7 changed files with 270 additions and 21 deletions
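For context, the call pattern that the new decompositions cover is an RNN applied to a PackedSequence, which routes to the aten.rnn_relu.data / aten.rnn_tanh.data overloads rather than the .input ones. A minimal sketch with hypothetical sizes:

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

rnn = torch.nn.RNN(input_size=2, hidden_size=3, nonlinearity='relu')
padded = torch.randn(5, 2, 2)                          # (T=5, B=2, F=2), batch_first=False
packed = pack_padded_sequence(padded, lengths=[5, 3])  # lengths sorted in decreasing order
out_packed, h_n = rnn(packed)                          # ends up in aten.rnn_relu.data
out, lengths = pad_packed_sequence(out_packed)         # back to a padded (T, B, H) tensor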

View File

@@ -23,6 +23,7 @@ import unittest
import warnings
import itertools
from functools import partial
from torch.nn.utils.rnn import PackedSequence
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed
from torch.testing._internal.common_modules import module_db, modules
@@ -2517,11 +2518,17 @@ def _test_aot_autograd_module_helper(self, device, dtype, training, module_info)
# Lazy modules need to see an input first to initialize params.
args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
flat_args, args_spec = pytree.tree_flatten((args, kwargs))
# PackedSequence is only used for RNNs. It might be possible to fake-ify them if they were pytrees, but
# torchdynamo doesn't support RNNs anyway
if any(tuple(isinstance(flat_arg, PackedSequence) for flat_arg in flat_args)):
continue
if issubclass(module_info.module_cls, torch.nn.modules.lazy.LazyModuleMixin):
with torch.no_grad():
m(*args, **kwargs)
flat_args, args_spec = pytree.tree_flatten((args, kwargs))
sentinel_val = -42
is_tensor_spec = [sentinel_val if isinstance(arg, torch.Tensor)
else arg for arg in flat_args]

View File

@@ -443,7 +443,6 @@ class TestDecomp(TestCase):
# they're checking aten decomps at the torch_dispatch level
self.assertEqual(decomp_out, non_decomp_out)
class DecompCrossRefMode(TorchDispatchMode):
def __init__(self, test_case, saved_precision, saved_rel_tol, dtype, run_all):
self.test_case = test_case

View File

@@ -80,13 +80,14 @@ class TestExpandedWeightHelperFunction(TestCase):
def test_set_grad_sample_if_exists(self, device):
def test_fn(a):
return True
return grad_sample
orig_weight = torch.randn(4, device=device, requires_grad=True)
expanded_weight = ExpandedWeight(orig_weight, 3, loss_reduction="sum")
grad_sample = torch.randn(3)
set_grad_sample_if_exists(expanded_weight, test_fn)
self.assertTrue(hasattr(orig_weight, 'grad_sample'))
self.assertTrue(orig_weight.grad_sample)
self.assertEqual(orig_weight.grad_sample, grad_sample)
basic_tensor = torch.randn(4, device=device)
set_grad_sample_if_exists(basic_tensor, test_fn)
@@ -474,6 +475,43 @@ class TestExpandedWeightModule(TestCase):
expected_grads = tuple(expected_grad for expected_grad in expected_grads if expected_grad is not None)
assert [self.assertEqual(actual, 2 * expected) for (actual, expected) in zip(actual_grads, expected_grads)]
def _do_test_rnn_packed_sequence(self, module, input, args=None, kwargs=None):
args = args if args is not None else ()
kwargs = kwargs if kwargs is not None else {}
batch_size = max(tuple(input.batch_sizes)).item()
with freeze_rng_state():
# get per sample grads with ExpandedWeights context manager
actual_res = call_for_per_sample_grads(module,
batch_size=batch_size,
loss_reduction="sum")(input, *args, **kwargs).data.sum()
actual_res.backward()
actual_grads = []
for param in module.parameters():
self.assertEqual(param.grad_sample.shape[0], batch_size)
actual_grads.append(param.grad_sample)
del param.grad_sample
input.data.grad = torch.zeros_like(input.data)
# compute the per sample grads with a for loop
expected_res = torch.zeros_like(actual_res)
expected_grads = []
padded_input, seq_sizes = torch.nn.utils.rnn.pad_packed_sequence(input, batch_first=True)
for i in range(len(seq_sizes)):
input_slice = padded_input[i].narrow(0, 0, seq_sizes[i])
diff_params = module.parameters()
batch_dim = 0 if module.m.batch_first else 1
res = module(input_slice.unsqueeze(batch_dim), *args, **kwargs).sum()
expected_res += res
out_grads = torch.autograd.grad(res, diff_params, torch.ones_like(res), allow_unused=True)
expected_grads.append(out_grads)
expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)]
self.assertEqual(actual_res, expected_res)
[self.assertEqual(actual, expected) for (actual, expected) in zip(actual_grads, expected_grads)]
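For reference, a standalone sketch of the pattern this helper exercises (hypothetical module and sizes; it assumes call_for_per_sample_grads is importable from torch.nn.utils._per_sample_grad, as used by this test file):

import torch
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads

rnn = torch.nn.RNN(input_size=2, hidden_size=3)
packed = pack_padded_sequence(torch.randn(5, 2, 2), lengths=[5, 3])
batch_size = int(max(packed.batch_sizes))     # the largest batch in the packed sequence
out, h_n = call_for_per_sample_grads(rnn, batch_size=batch_size, loss_reduction="sum")(packed)
out.data.sum().backward()
for p in rnn.parameters():
    assert p.grad_sample.shape[0] == batch_size   # one gradient slice per sample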
@modules(filter(lambda m_info: m_info.module_cls in (torch.nn.RNN, torch.nn.LSTM), module_db))
def test_module(self, device, dtype, module_info, training):
class RNNWrapper(torch.nn.Module):
@@ -494,7 +532,7 @@ class TestExpandedWeightModule(TestCase):
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=True, training=training)
requires_grad=True, training=training, with_packed_sequence=True)
for module_input in module_inputs:
if module_input.forward_input is None:
continue
@@ -506,8 +544,9 @@ class TestExpandedWeightModule(TestCase):
args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
# if the RNN tests use unbatched inputs, batch the inputs
input = args[0].detach()
if input.dim() == 2:
input = args[0]
if isinstance(input, torch.Tensor) and input.dim() == 2:
input = input.detach()
new_input_shape = [1] * (len(input.shape) + 1)
if batch_first:
new_input_shape[0] = 2
@@ -522,7 +561,10 @@ class TestExpandedWeightModule(TestCase):
args = list(args)
args[1] = h
self._do_test(m, input, args[1:], kwargs, batch_first=batch_first)
if isinstance(input, torch.nn.utils.rnn.PackedSequence):
self._do_test_rnn_packed_sequence(m, input, args[1:], kwargs)
else:
self._do_test(m, input, args[1:], kwargs, batch_first=batch_first)
def test_per_sample_api_failing(self):
module = nn.Linear(10, 10)

View File

@@ -2149,6 +2149,73 @@ def params_hiddens(params, hiddens, i, bidirectional):
return cur_params, cur_hidden, bidir_params, bidir_hidden
def update_hidden_for_packed(cur_hidden, last_batch_size, batch_size, hiddens):
assert last_batch_size > batch_size
hiddens.append(cur_hidden.narrow(0, batch_size, last_batch_size - batch_size))
return cur_hidden.narrow(0, 0, batch_size)
def update_hidden_for_packed_reverse(
cur_hidden, last_batch_size, batch_size, inp_hidden
):
if last_batch_size == batch_size:
return cur_hidden
assert last_batch_size < batch_size
return torch.concat(
(
cur_hidden,
inp_hidden.narrow(0, last_batch_size, batch_size - last_batch_size),
)
)
def one_layer_rnn_data(
inp, hidden, params, has_biases, nonlinearity, batch_sizes, reverse=False
):
ih_weight = params[0]
hh_weight = params[1]
ih_bias = params[2] if has_biases else None
hh_bias = params[3] if has_biases else None
step_output = []
hiddens: List["torch.Tensor"] = []
last_batch_size = batch_sizes[-1] if reverse else batch_sizes[0]
cur_hidden = hidden.narrow(0, 0, last_batch_size)
split_inp = torch.split(inp, list(batch_sizes))
if reverse:
split_inp = split_inp[::-1]
for inp in split_inp:
i = inp.shape[0]
if last_batch_size == i:
pass # don't update cur_hidden
# this will only happen when reverse=False, since batch sizes are sorted largest -> smallest
elif reverse:
cur_hidden = update_hidden_for_packed_reverse(
cur_hidden, last_batch_size, i, hidden
)
else:
cur_hidden = update_hidden_for_packed(
cur_hidden, last_batch_size, i, hiddens
)
inp = F.linear(inp, ih_weight, ih_bias)
cur_hidden = nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + inp)
last_batch_size = i
step_output.append(cur_hidden)
if reverse:
step_output.reverse()
else:
hiddens.append(cur_hidden)
hiddens.reverse()
out = torch.cat(step_output, 0)
hidden_out = torch.cat(hiddens, 0) if not reverse else cur_hidden
return out, hidden_out
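As a toy illustration of the layout this loop walks (hypothetical sizes): a PackedSequence stores the time steps concatenated along dim 0, and batch_sizes records how many sequences are still active at each step, which is exactly what the torch.split above recovers:

import torch
from torch.nn.utils.rnn import pack_padded_sequence

padded = torch.randn(3, 3, 1)                        # (T=3, B=3, F=1), lengths 3, 2, 1
packed = pack_padded_sequence(padded, lengths=[3, 2, 1])
print(packed.batch_sizes)                            # tensor([3, 2, 1])
chunks = torch.split(packed.data, packed.batch_sizes.tolist())
# chunks[t] holds the batch_sizes[t] rows active at step t; when the batch shrinks,
# one_layer_rnn_data narrows cur_hidden and stashes the finished rows in hiddens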
def one_layer_rnn(inp, hidden, params, has_biases, nonlinearity, reverse=False):
ih_weight = params[0]
hh_weight = params[1]
@@ -2163,6 +2230,9 @@ def one_layer_rnn(inp, hidden, params, has_biases, nonlinearity, reverse=False):
cur_hidden = nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + inp)
step_output.append(cur_hidden)
if reverse:
step_output.reverse()
out = torch.cat(step_output, 0)
return out, cur_hidden.squeeze(0)
@@ -2195,7 +2265,6 @@ def _rnn_helper(
bwd_inp, bwd_hidden = layer_fn(
input, bidir_hidden, bidir_params, has_biases, reverse=True
)
bwd_inp = bwd_inp.flip(0)
final_hiddens.append(bwd_hidden)
if bidirectional:
@@ -2272,6 +2341,68 @@ def rnn_relu_input(
return out, torch.stack(final_hiddens, 0)
@register_decomposition(aten.rnn_relu.data)
@aten.rnn_relu.data.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.rnn_relu.data.py_impl(DispatchKey.Autograd)
def rnn_relu_data(
data,
batch_sizes,
hx,
params,
has_biases,
num_layers,
dropout,
train,
bidirectional,
):
hidden = hx.unbind(0)
params = gather_params(params, has_biases, False)
out, final_hiddens = _rnn_helper(
data,
hidden,
params,
has_biases,
num_layers,
dropout,
train,
bidirectional,
False,
partial(one_layer_rnn_data, batch_sizes=batch_sizes, nonlinearity=torch.relu),
)
return out, torch.stack(final_hiddens, 0)
@register_decomposition(aten.rnn_tanh.data)
@aten.rnn_tanh.data.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.rnn_tanh.data.py_impl(DispatchKey.Autograd)
def rnn_tanh_data(
data,
batch_sizes,
hx,
params,
has_biases,
num_layers,
dropout,
train,
bidirectional,
):
hidden = hx.unbind(0)
params = gather_params(params, has_biases, False)
out, final_hiddens = _rnn_helper(
data,
hidden,
params,
has_biases,
num_layers,
dropout,
train,
bidirectional,
False,
partial(one_layer_rnn_data, batch_sizes=batch_sizes, nonlinearity=torch.tanh),
)
return out, torch.stack(final_hiddens, 0)
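A rough, untested sketch of exercising the packed-data decomposition directly (hypothetical sizes, single layer, unidirectional; the flat parameter list is assumed to come from nn.RNN's internal _flat_weights, and the argument order follows the rnn_tanh_data signature above):

import torch
from torch._decomp import decomposition_table
from torch.nn.utils.rnn import pack_padded_sequence

rnn = torch.nn.RNN(input_size=2, hidden_size=3)      # tanh nonlinearity by default
packed = pack_padded_sequence(torch.randn(4, 2, 2), lengths=[4, 2])
hx = torch.zeros(1, 2, 3)                            # (num_layers * num_directions, B, H)
params = [w.detach() for w in rnn._flat_weights]     # weight_ih, weight_hh, bias_ih, bias_hh
decomp = decomposition_table[torch.ops.aten.rnn_tanh.data]
out, h_n = decomp(packed.data, packed.batch_sizes, hx, params,
                  True,    # has_biases
                  1,       # num_layers
                  0.0,     # dropout
                  False,   # train
                  False)   # bidirectional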
def one_layer_lstm(inp, hidden, params, has_biases, reverse=False):
ih_weight = params[0]
hh_weight = params[1]
@@ -2302,6 +2433,9 @@ def one_layer_lstm(inp, hidden, params, has_biases, reverse=False):
hx = hy
cx = cy
if reverse:
step_output.reverse()
out = torch.cat(step_output, 0)
return out, (hx.squeeze(1), cx.squeeze(1))

View File

@@ -11,24 +11,54 @@ from torch.utils._pytree import tree_map_only
HANDLED_FUNCTIONS: Dict[Callable, torch.autograd.Function] = {}
# __torch_function__ runs before the pydispatcher so we need to use the same
aten = torch._ops.ops.aten
# __torch_function__ runs before the pydispatcher so we need to manually use the same
# decompositions indexed by their torch equivalent
expanded_weights_rnn_decomps = {
# func: (input_decomp, data_decomp)
torch.rnn_relu: (decomposition_table[torch._ops.ops.aten.rnn_relu.input], None),
torch.rnn_tanh: (decomposition_table[torch._ops.ops.aten.rnn_tanh.input], None),
torch.lstm: (decomposition_table[torch._ops.ops.aten.lstm.input], None),
torch.rnn_relu: (decomposition_table[aten.rnn_relu.input], decomposition_table[aten.rnn_relu.data]),
torch.rnn_tanh: (decomposition_table[aten.rnn_tanh.input], decomposition_table[aten.rnn_tanh.data]),
torch.lstm: (decomposition_table[aten.lstm.input], None),
}
# all of the RNN decomps run linear with the batch dimension second, even if batch_first was set
@contextmanager
def batch_second(args, kwargs):
tree_map_only(ExpandedWeight, functools.partial(ExpandedWeight.set_batch_first, is_batch_first=False), args)
tree_map_only(ExpandedWeight, functools.partial(ExpandedWeight.set_batch_first, is_batch_first=False), kwargs)
def set_batch_second(ew):
ew.set_batch_first(False)
def reset_batch_first(ew):
ew.set_batch_first(True)
tree_map_only(ExpandedWeight, set_batch_second, args)
tree_map_only(ExpandedWeight, set_batch_second, kwargs)
try:
yield
finally:
tree_map_only(ExpandedWeight, functools.partial(ExpandedWeight.set_batch_first, is_batch_first=True), args)
tree_map_only(ExpandedWeight, functools.partial(ExpandedWeight.set_batch_first, is_batch_first=True), kwargs)
tree_map_only(ExpandedWeight, reset_batch_first, args)
tree_map_only(ExpandedWeight, reset_batch_first, kwargs)
# to support packed sequences, we need to allow smaller batches: an ExpandedWeight's batch_size is the largest batch in the packed sequence
@contextmanager
def allow_smaller_batches(args, kwargs):
def allow(ew):
ew.set_allow_smaller_batches(True)
def reset(ew):
ew.set_allow_smaller_batches(False)
tree_map_only(ExpandedWeight, allow, args)
tree_map_only(ExpandedWeight, allow, kwargs)
try:
yield
finally:
tree_map_only(ExpandedWeight, reset, args)
tree_map_only(ExpandedWeight, reset, kwargs)
@contextmanager
def setup_rnn(use_input_variant, args, kwargs):
with batch_second(args, kwargs) if use_input_variant else allow_smaller_batches(args, kwargs):
yield
def implements_per_sample_grads(torch_function):
@@ -54,6 +84,7 @@ class ExpandedWeight(torch.Tensor):
def __init__(self, orig_weight, batch_size, loss_reduction):
self.batch_size = batch_size
self.batch_first = True
self.allow_smaller_batches = False
self.orig_weight = orig_weight
self.loss_reduction = loss_reduction
@@ -74,11 +105,11 @@ class ExpandedWeight(torch.Tensor):
if func in expanded_weights_rnn_decomps:
# in aten, choosing the input or data variants is done by parsing logic. This mimics some of that
decomp_opts = expanded_weights_rnn_decomps[func]
use_input_variant = not isinstance(args[1], list) # data variant uses a list here
use_input_variant = isinstance(args[2], list) # the input variant passes params (a list) at position 2
decomp = decomp_opts[0] if use_input_variant else decomp_opts[1]
if decomp is not None:
with batch_second(args, kwargs):
with setup_rnn(use_input_variant, args, kwargs):
return decomp(*args, **kwargs)
if func == torch._cudnn_rnn_flatten_weight:
# since we aren't using the fused cuda kernels for RNNs, don't do this
@@ -115,5 +146,8 @@ class ExpandedWeight(torch.Tensor):
def get_device(self):
return self.orig_weight.get_device()
def set_allow_smaller_batches(self, is_allow_smaller_batches):
self.allow_smaller_batches = is_allow_smaller_batches
def set_batch_first(self, is_batch_first=True):
self.batch_first = is_batch_first
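The batch_second and allow_smaller_batches context managers above use tree_map_only purely for its side effect, visiting every ExpandedWeight nested anywhere in args and kwargs. A toy sketch with a hypothetical stand-in class:

import torch
from torch.utils._pytree import tree_map_only

class Toy:                        # stand-in leaf type playing the role of ExpandedWeight
    def __init__(self):
        self.batch_first = True

args = (Toy(), torch.randn(2), [Toy(), {"h0": Toy()}])
tree_map_only(Toy, lambda t: setattr(t, "batch_first", False), args)  # mapped result is discarded
assert not args[0].batch_first and not args[2][0].batch_first and not args[2][1]["h0"].batch_first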

View File

@@ -62,7 +62,8 @@ def _check_and_unexpand_args(func, expanded_args, expanded_kwargs):
if not isinstance(arg, ExpandedWeight):
continue
batch_size = input.shape[0] if arg.batch_first else input.shape[1]
if arg.batch_size != batch_size:
if (arg.allow_smaller_batches and batch_size > arg.batch_size) or \
(not arg.allow_smaller_batches and arg.batch_size != batch_size):
raise RuntimeError("Expected ExpandedWeights to have batch size matching input but got "
f"input batch size of {batch_size} with ExpandedWeight of batch size {arg.batch_size}")
@@ -90,6 +91,15 @@ def set_grad_sample_if_exists(maybe_expanded_weight, per_sample_grad_fn):
unpacked = unpack_expanded_weight_or_tensor(maybe_expanded_weight)
if isinstance(maybe_expanded_weight, ExpandedWeight):
grad_sample_contribution = maybe_scale_by_batch_size(per_sample_grad_fn(unpacked), maybe_expanded_weight)
if maybe_expanded_weight.batch_size > grad_sample_contribution.shape[0]:
# this only passes the other checks if the arg allows smaller batch sizes
intermediate = torch.zeros(maybe_expanded_weight.batch_size, *grad_sample_contribution.shape[1:],
dtype=grad_sample_contribution.dtype,
device=grad_sample_contribution.device)
intermediate[:grad_sample_contribution.shape[0]] = grad_sample_contribution
grad_sample_contribution = intermediate
if hasattr(unpacked, "grad_sample") and unpacked.grad_sample is not None:
unpacked.grad_sample = unpacked.grad_sample + grad_sample_contribution
else:

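A toy illustration of the zero-padding above (hypothetical shapes): when only some samples are active, their per-sample grads are embedded into a buffer sized to the ExpandedWeight's full batch, and the inactive rows stay zero:

import torch

full_batch = 5                                  # ExpandedWeight.batch_size
contribution = torch.randn(3, 4)                # grads for the 3 samples active in this call
padded = torch.zeros(full_batch, *contribution.shape[1:],
                     dtype=contribution.dtype, device=contribution.device)
padded[:contribution.shape[0]] = contribution   # rows 3 and 4 remain zero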
View File

@@ -6,6 +6,7 @@ from functools import wraps, partial
from itertools import chain, product
import itertools
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import TEST_CUDNN
from torch.testing._internal.common_dtype import floating_types, floating_and_complex_types_and
@@ -947,8 +948,15 @@ def module_inputs_torch_nn_LSTMCell(module_info, device, dtype, requires_grad, t
return samples
def make_packed_sequence(inp, batch_sizes):
required_grad = inp.requires_grad
inp.requires_grad_(False) # user won't have access to inp so won't be able to get its grads
seq = pack_padded_sequence(inp, batch_sizes)
seq.data.requires_grad_(required_grad)
return seq
def module_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, training, **kwargs):
def module_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, training, with_packed_sequence=False, **kwargs):
# Currently all samples below are for validating the no-batch-dim support.
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
is_rnn = kwargs['is_rnn']
@@ -991,6 +999,21 @@ def module_inputs_torch_nn_RNN_GRU(module_info, device, dtype, requires_grad, tr
reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f),
)
)
if with_packed_sequence:
samples.append(
ModuleInput(
constructor_input=FunctionInput(**cons_args),
forward_input=FunctionInput(make_packed_sequence(make_input((5, 2, 2)), torch.tensor([5, 3]))),
reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f),
)
)
samples.append(
ModuleInput(
constructor_input=FunctionInput(**cons_args),
forward_input=FunctionInput(make_packed_sequence(make_input((5, 5, 2)), torch.tensor([5, 3, 3, 2, 2]))),
reference_fn=partial(no_batch_dim_reference_rnn_gru, batch_first=b_f),
)
)
return samples
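For reference, a quick look at what the first packed sample above contains (batch_first=False, so the input is (T=5, B=2, F=2) with lengths [5, 3]; values are hypothetical):

import torch
from torch.nn.utils.rnn import pack_padded_sequence

seq = pack_padded_sequence(torch.randn(5, 2, 2), torch.tensor([5, 3]))
print(seq.batch_sizes)    # tensor([2, 2, 2, 1, 1]): both sequences for 3 steps, then only the first
print(seq.data.shape)     # torch.Size([8, 2]): 5 + 3 packed rows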