[decompositions] GRU decomposition with and without packed sequence (#91466)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91466
Approved by: https://github.com/zou3519
lezcano, 2023-02-08 09:36:21 +00:00, committed by PyTorch MergeBot
parent 5a7c1b7894
commit fe0e28ab87
4 changed files with 120 additions and 18 deletions
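
Editorial note (not part of the diff): this change adds Python decompositions for both GRU overloads, aten.gru.input for padded input and aten.gru.data for PackedSequence input. A minimal sketch of the two call paths through torch.nn.GRU follows; the sizes are illustrative.

# Illustrative only: both overloads are reached through torch.nn.GRU.
import torch
from torch.nn.utils.rnn import pack_padded_sequence

gru = torch.nn.GRU(input_size=4, hidden_size=3, num_layers=1)
h0 = torch.zeros(1, 2, 3)                 # (num_layers, batch, hidden)
x = torch.randn(5, 2, 4)                  # (seq, batch, feature)

out, hn = gru(x, h0)                      # padded input -> aten.gru.input

packed = pack_padded_sequence(x, lengths=[5, 3])
packed_out, hn_packed = gru(packed, h0)   # PackedSequence -> aten.gru.data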


@@ -422,7 +422,7 @@ class TestDecomp(TestCase):
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
@suppress_warnings
# only tests RNNs since we have py dispatcher decomps for them
@modules(filter(lambda m: m.module_cls in (torch.nn.RNN, torch.nn.LSTM), module_db))
@modules(filter(lambda m: m.module_cls in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU), module_db))
def test_rnn_decomp_module(self, device, dtype, module_info, training):
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
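
Editorial note (not part of the diff): the py dispatcher decomps referenced in the comment above live in torch._decomp.decomposition_table, keyed by aten overloads. A minimal lookup sketch, assuming this change is applied:

import torch
from torch._decomp import decomposition_table

gru_input_decomp = decomposition_table[torch.ops.aten.gru.input]  # registered as gru_impl below
gru_data_decomp = decomposition_table[torch.ops.aten.gru.data]    # registered as gru_impl_data below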


@@ -387,7 +387,7 @@ class TestExpandedWeightFunctional(TestCase):
F.group_norm(inp, 2) # 5 is not divisible by 2
class TestExpandedWeightModule(TestCase):
def _do_test(self, module, input, args=None, kwargs=None, batch_first=True):
def _do_test(self, module, input, args=None, kwargs=None, batch_first=True, atol=None, rtol=None):
args = args or ()
kwargs = kwargs or {}
@@ -432,7 +432,7 @@ class TestExpandedWeightModule(TestCase):
if not batch_first:
expected_grads[-1] = expected_grads[-1].transpose(0, 1)
self.assertEqual(actual_res, expected_res)
[self.assertEqual(actual, expected) for (actual, expected) in zip(actual_grads, expected_grads)]
[self.assertEqual(actual, expected, atol=atol, rtol=rtol) for (actual, expected) in zip(actual_grads, expected_grads)]
def _do_test_multi_input(self, module, input):
class TestModule(nn.Module):
@@ -475,7 +475,7 @@ class TestExpandedWeightModule(TestCase):
expected_grads = tuple(expected_grad for expected_grad in expected_grads if expected_grad is not None)
assert [self.assertEqual(actual, 2 * expected) for (actual, expected) in zip(actual_grads, expected_grads)]
def _do_test_rnn_packed_sequence(self, module, input, args=None, kwargs=None):
def _do_test_rnn_packed_sequence(self, module, input, args=None, kwargs=None, atol=None, rtol=None):
args = args if args is not None else ()
kwargs = kwargs if kwargs is not None else {}
@@ -510,9 +510,9 @@ class TestExpandedWeightModule(TestCase):
expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)]
self.assertEqual(actual_res, expected_res)
[self.assertEqual(actual, expected) for (actual, expected) in zip(actual_grads, expected_grads)]
[self.assertEqual(actual, expected, atol=atol, rtol=rtol) for (actual, expected) in zip(actual_grads, expected_grads)]
@modules(filter(lambda m_info: m_info.module_cls in (torch.nn.RNN, torch.nn.LSTM), module_db))
@modules(filter(lambda m_info: m_info.module_cls in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU), module_db))
def test_module(self, device, dtype, module_info, training):
class RNNWrapper(torch.nn.Module):
def __init__(self, m_cons, args, kwargs):
@@ -531,6 +531,7 @@ class TestExpandedWeightModule(TestCase):
module_cls = module_info.module_cls
atol, rtol = (1e-4, 1e-5) if module_cls == torch.nn.GRU and dtype == torch.float32 else (None, None)
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=True, training=training, with_packed_sequence=True)
for module_input in module_inputs:
@@ -562,9 +563,9 @@ class TestExpandedWeightModule(TestCase):
args[1] = h
if isinstance(input, torch.nn.utils.rnn.PackedSequence):
self._do_test_rnn_packed_sequence(m, input, args[1:], kwargs)
self._do_test_rnn_packed_sequence(m, input, args[1:], kwargs, atol=atol, rtol=rtol)
else:
self._do_test(m, input, args[1:], kwargs, batch_first=batch_first)
self._do_test(m, input, args[1:], kwargs, batch_first=batch_first, atol=atol, rtol=rtol)
def test_per_sample_api_failing(self):
module = nn.Linear(10, 10)
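
Editorial note (not part of the diff): the new atol/rtol arguments relax the comparison only for GRU in float32. Conceptually, the reference side of that comparison is per-sample gradients computed one sample at a time, against which the expanded-weights results are checked. A sketch of that reference computation, with illustrative sizes:

# Reference per-sample gradients, computed by running the module once per sample.
import torch

gru = torch.nn.GRU(input_size=4, hidden_size=3)
x = torch.randn(5, 2, 4)                  # (seq, batch=2, feature)

expected_grads = []
for b in range(x.shape[1]):
    gru.zero_grad()
    out, _ = gru(x[:, b:b + 1])
    out.sum().backward()
    expected_grads.append([p.grad.clone() for p in gru.parameters()])

# One stacked tensor per parameter, indexed by sample, as in the test above.
expected_grads = [torch.stack(g) for g in zip(*expected_grads)]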


@@ -2170,7 +2170,7 @@ def update_hidden_for_packed_reverse(
def one_layer_rnn_data(
inp, hidden, params, has_biases, nonlinearity, batch_sizes, reverse=False
inp, hidden, params, has_biases, hidden_fn, batch_sizes, reverse=False
):
ih_weight = params[0]
hh_weight = params[1]
@@ -2200,8 +2200,7 @@ def one_layer_rnn_data(
cur_hidden, last_batch_size, i, hiddens
)
inp = F.linear(inp, ih_weight, ih_bias)
cur_hidden = nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + inp)
cur_hidden = hidden_fn(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias)
last_batch_size = i
step_output.append(cur_hidden)
@@ -2216,7 +2215,22 @@ def one_layer_rnn_data(
return out, hidden_out
def one_layer_rnn(inp, hidden, params, has_biases, nonlinearity, reverse=False):
def rnn_cell(nonlinearity):
def inner(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
return nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + i)
return inner
def rnn_cell_data(nonlinearity):
def inner(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
i = F.linear(i, ih_weight, ih_bias)
return nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + i)
return inner
def one_layer_rnn(inp, hidden, params, has_biases, hidden_fn, reverse=False):
ih_weight = params[0]
hh_weight = params[1]
ih_bias = params[2] if has_biases else None
@@ -2226,8 +2240,8 @@ def one_layer_rnn(inp, hidden, params, has_biases, nonlinearity, reverse=False):
precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input
cur_hidden = hidden.unsqueeze(0)
step_output = []
for inp in precomputed_input:
cur_hidden = nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + inp)
for i in precomputed_input:
cur_hidden = hidden_fn(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias)
step_output.append(cur_hidden)
if reverse:
@@ -2305,7 +2319,7 @@ def rnn_tanh_input(
train,
bidirectional,
batch_first,
partial(one_layer_rnn, nonlinearity=torch.tanh),
partial(one_layer_rnn, hidden_fn=rnn_cell(torch.tanh)),
)
return out, torch.stack(final_hiddens, 0)
@@ -2336,7 +2350,7 @@ def rnn_relu_input(
train,
bidirectional,
batch_first,
partial(one_layer_rnn, nonlinearity=torch.relu),
partial(one_layer_rnn, hidden_fn=rnn_cell(torch.relu)),
)
return out, torch.stack(final_hiddens, 0)
@@ -2367,7 +2381,11 @@ def rnn_relu_data(
train,
bidirectional,
False,
partial(one_layer_rnn_data, batch_sizes=batch_sizes, nonlinearity=torch.relu),
partial(
one_layer_rnn_data,
batch_sizes=batch_sizes,
hidden_fn=rnn_cell_data(torch.relu),
),
)
return out, torch.stack(final_hiddens, 0)
@@ -2398,7 +2416,11 @@ def rnn_tanh_data(
train,
bidirectional,
False,
partial(one_layer_rnn_data, batch_sizes=batch_sizes, nonlinearity=torch.tanh),
partial(
one_layer_rnn_data,
batch_sizes=batch_sizes,
hidden_fn=rnn_cell_data(torch.tanh),
),
)
return out, torch.stack(final_hiddens, 0)
@@ -2573,6 +2595,84 @@ def lstm_data_impl(
return out, torch.stack(final_hiddens[0], 0), torch.stack(final_hiddens[1], 0)
def gru_cell(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
chunked_igates = inp.chunk(3, 1)
chunked_hgates = F.linear(cur_hidden, hh_weight, hh_bias).chunk(3, 2)
reset_gate = (chunked_hgates[0] + chunked_igates[0]).sigmoid()
input_gate = (chunked_hgates[1] + chunked_igates[1]).sigmoid()
new_gate = (chunked_igates[2] + (chunked_hgates[2] * reset_gate)).tanh()
return (cur_hidden - new_gate) * input_gate + new_gate
def gru_cell_data(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
chunked_igates = F.linear(inp, ih_weight, ih_bias).chunk(3, 1)
chunked_hgates = F.linear(cur_hidden, hh_weight, hh_bias).chunk(3, 1)
reset_gate = (chunked_hgates[0] + chunked_igates[0]).sigmoid()
input_gate = (chunked_hgates[1] + chunked_igates[1]).sigmoid()
new_gate = (chunked_igates[2] + (chunked_hgates[2] * reset_gate)).tanh()
return (cur_hidden - new_gate) * input_gate + new_gate
@register_decomposition(aten.gru.data)
@aten.gru.data.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.gru.data.py_impl(DispatchKey.Autograd)
def gru_impl_data(
data,
batch_sizes,
hx,
params,
has_biases,
num_layers,
dropout,
train,
bidirectional,
):
params = gather_params(params, has_biases, False)
out, final_hiddens = _rnn_helper(
data,
hx.unbind(0),
params,
has_biases,
num_layers,
dropout,
train,
bidirectional,
False,
partial(one_layer_rnn_data, batch_sizes=batch_sizes, hidden_fn=gru_cell_data),
)
return out, torch.stack(final_hiddens, 0)
@register_decomposition(aten.gru.input)
@aten.gru.input.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.gru.input.py_impl(DispatchKey.Autograd)
def gru_impl(
input,
hx,
params,
has_biases,
num_layers,
dropout,
train,
bidirectional,
batch_first,
):
params = gather_params(params, has_biases, False)
out, final_hiddens = _rnn_helper(
input,
hx.unbind(0),
params,
has_biases,
num_layers,
dropout,
train,
bidirectional,
batch_first,
partial(one_layer_rnn, hidden_fn=gru_cell),
)
return out, torch.stack(final_hiddens, 0)
@register_decomposition(aten.upsample_bilinear2d.vec)
@aten.upsample_bilinear2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.upsample_bilinear2d.vec.py_impl(DispatchKey.Autograd)
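
Editorial note (not part of the diff): gru_cell and gru_cell_data implement the standard GRU update (reset gate, update gate, candidate state, then a convex combination with the previous hidden state); gru_cell expects the input projection to be precomputed, while gru_cell_data applies it per step. A sanity-check sketch of that gate math against torch.nn.GRUCell, with illustrative sizes; agreement is up to default float32 tolerances.

# One step of the gate math from gru_cell_data, checked against torch.nn.GRUCell.
import torch
import torch.nn.functional as F

cell = torch.nn.GRUCell(4, 3)
x = torch.randn(2, 4)                     # (batch, input_size)
h = torch.randn(2, 3)                     # (batch, hidden_size)

igates = F.linear(x, cell.weight_ih, cell.bias_ih).chunk(3, 1)
hgates = F.linear(h, cell.weight_hh, cell.bias_hh).chunk(3, 1)
reset_gate = (hgates[0] + igates[0]).sigmoid()
input_gate = (hgates[1] + igates[1]).sigmoid()
new_gate = (igates[2] + hgates[2] * reset_gate).tanh()
h_next = (h - new_gate) * input_gate + new_gate   # == (1 - z) * n + z * h

torch.testing.assert_close(h_next, cell(x, h))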


@@ -19,6 +19,7 @@ expanded_weights_rnn_decomps = {
torch.rnn_relu: (decomposition_table[aten.rnn_relu.input], decomposition_table[aten.rnn_relu.data]),
torch.rnn_tanh: (decomposition_table[aten.rnn_tanh.input], decomposition_table[aten.rnn_tanh.data]),
torch.lstm: (decomposition_table[aten.lstm.input], decomposition_table[aten.lstm.data]),
torch.gru: (decomposition_table[aten.gru.input], decomposition_table[aten.gru.data]),
}
# all of the RNN decomps run linear with the batch dimension second, even if batch_first was set
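
Editorial note (not part of the diff): the convention noted above exists because F.linear acts on the last dimension, so the decompositions keep activations in (seq, batch, feature) layout and transpose batch_first inputs up front. A tiny illustration with made-up sizes:

import torch
import torch.nn.functional as F

w = torch.randn(3, 4)
x = torch.randn(5, 2, 4)                  # (seq, batch, feature)
x_bf = x.transpose(0, 1).contiguous()     # (batch, seq, feature), i.e. batch_first layout

assert F.linear(x, w).shape == (5, 2, 3)
assert torch.allclose(F.linear(x_bf.transpose(0, 1), w), F.linear(x, w))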