mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[decompositions] GRU decompositon with and without packed sequence (#91466)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91466 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
5a7c1b7894
commit
fe0e28ab87
@ -422,7 +422,7 @@ class TestDecomp(TestCase):
|
||||
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
|
||||
@suppress_warnings
|
||||
# only tests RNNs since we have py dispsatcher decomps for them
|
||||
@modules(filter(lambda m: m.module_cls in (torch.nn.RNN, torch.nn.LSTM), module_db))
|
||||
@modules(filter(lambda m: m.module_cls in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU), module_db))
|
||||
def test_rnn_decomp_module(self, device, dtype, module_info, training):
|
||||
module_cls = module_info.module_cls
|
||||
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
|
||||
|
@ -387,7 +387,7 @@ class TestExpandedWeightFunctional(TestCase):
|
||||
F.group_norm(inp, 2) # 5 is not divisible by 2
|
||||
|
||||
class TestExpandedWeightModule(TestCase):
|
||||
def _do_test(self, module, input, args=None, kwargs=None, batch_first=True):
|
||||
def _do_test(self, module, input, args=None, kwargs=None, batch_first=True, atol=None, rtol=None):
|
||||
args = args or ()
|
||||
kwargs = kwargs or {}
|
||||
|
||||
@ -432,7 +432,7 @@ class TestExpandedWeightModule(TestCase):
|
||||
if not batch_first:
|
||||
expected_grads[-1] = expected_grads[-1].transpose(0, 1)
|
||||
self.assertEqual(actual_res, expected_res)
|
||||
[self.assertEqual(actual, expected) for (actual, expected) in zip(actual_grads, expected_grads)]
|
||||
[self.assertEqual(actual, expected, atol=atol, rtol=rtol) for (actual, expected) in zip(actual_grads, expected_grads)]
|
||||
|
||||
def _do_test_multi_input(self, module, input):
|
||||
class TestModule(nn.Module):
|
||||
@ -475,7 +475,7 @@ class TestExpandedWeightModule(TestCase):
|
||||
expected_grads = tuple(expected_grad for expected_grad in expected_grads if expected_grad is not None)
|
||||
assert [self.assertEqual(actual, 2 * expected) for (actual, expected) in zip(actual_grads, expected_grads)]
|
||||
|
||||
def _do_test_rnn_packed_sequence(self, module, input, args=None, kwargs=None):
|
||||
def _do_test_rnn_packed_sequence(self, module, input, args=None, kwargs=None, atol=None, rtol=None):
|
||||
args = args if args is not None else ()
|
||||
kwargs = kwargs if kwargs is not None else {}
|
||||
|
||||
@ -510,9 +510,9 @@ class TestExpandedWeightModule(TestCase):
|
||||
|
||||
expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)]
|
||||
self.assertEqual(actual_res, expected_res)
|
||||
[self.assertEqual(actual, expected) for (actual, expected) in zip(actual_grads, expected_grads)]
|
||||
[self.assertEqual(actual, expected, atol=atol, rtol=rtol) for (actual, expected) in zip(actual_grads, expected_grads)]
|
||||
|
||||
@modules(filter(lambda m_info: m_info.module_cls in (torch.nn.RNN, torch.nn.LSTM), module_db))
|
||||
@modules(filter(lambda m_info: m_info.module_cls in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU), module_db))
|
||||
def test_module(self, device, dtype, module_info, training):
|
||||
class RNNWrapper(torch.nn.Module):
|
||||
def __init__(self, m_cons, args, kwargs):
|
||||
@ -531,6 +531,7 @@ class TestExpandedWeightModule(TestCase):
|
||||
|
||||
|
||||
module_cls = module_info.module_cls
|
||||
atol, rtol = (1e-4, 1e-5) if module_cls == torch.nn.GRU and dtype == torch.float32 else (None, None)
|
||||
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
|
||||
requires_grad=True, training=training, with_packed_sequence=True)
|
||||
for module_input in module_inputs:
|
||||
@ -562,9 +563,9 @@ class TestExpandedWeightModule(TestCase):
|
||||
args[1] = h
|
||||
|
||||
if isinstance(input, torch.nn.utils.rnn.PackedSequence):
|
||||
self._do_test_rnn_packed_sequence(m, input, args[1:], kwargs)
|
||||
self._do_test_rnn_packed_sequence(m, input, args[1:], kwargs, atol=atol, rtol=rtol)
|
||||
else:
|
||||
self._do_test(m, input, args[1:], kwargs, batch_first=batch_first)
|
||||
self._do_test(m, input, args[1:], kwargs, batch_first=batch_first, atol=atol, rtol=rtol)
|
||||
|
||||
def test_per_sample_api_failing(self):
|
||||
module = nn.Linear(10, 10)
|
||||
|
@ -2170,7 +2170,7 @@ def update_hidden_for_packed_reverse(
|
||||
|
||||
|
||||
def one_layer_rnn_data(
|
||||
inp, hidden, params, has_biases, nonlinearity, batch_sizes, reverse=False
|
||||
inp, hidden, params, has_biases, hidden_fn, batch_sizes, reverse=False
|
||||
):
|
||||
ih_weight = params[0]
|
||||
hh_weight = params[1]
|
||||
@ -2200,8 +2200,7 @@ def one_layer_rnn_data(
|
||||
cur_hidden, last_batch_size, i, hiddens
|
||||
)
|
||||
|
||||
inp = F.linear(inp, ih_weight, ih_bias)
|
||||
cur_hidden = nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + inp)
|
||||
cur_hidden = hidden_fn(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias)
|
||||
last_batch_size = i
|
||||
step_output.append(cur_hidden)
|
||||
|
||||
@ -2216,7 +2215,22 @@ def one_layer_rnn_data(
|
||||
return out, hidden_out
|
||||
|
||||
|
||||
def one_layer_rnn(inp, hidden, params, has_biases, nonlinearity, reverse=False):
|
||||
def rnn_cell(nonlinearity):
|
||||
def inner(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
|
||||
return nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + i)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def rnn_cell_data(nonlinearity):
|
||||
def inner(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
|
||||
i = F.linear(i, ih_weight, ih_bias)
|
||||
return nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + i)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def one_layer_rnn(inp, hidden, params, has_biases, hidden_fn, reverse=False):
|
||||
ih_weight = params[0]
|
||||
hh_weight = params[1]
|
||||
ih_bias = params[2] if has_biases else None
|
||||
@ -2226,8 +2240,8 @@ def one_layer_rnn(inp, hidden, params, has_biases, nonlinearity, reverse=False):
|
||||
precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input
|
||||
cur_hidden = hidden.unsqueeze(0)
|
||||
step_output = []
|
||||
for inp in precomputed_input:
|
||||
cur_hidden = nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + inp)
|
||||
for i in precomputed_input:
|
||||
cur_hidden = hidden_fn(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias)
|
||||
step_output.append(cur_hidden)
|
||||
|
||||
if reverse:
|
||||
@ -2305,7 +2319,7 @@ def rnn_tanh_input(
|
||||
train,
|
||||
bidirectional,
|
||||
batch_first,
|
||||
partial(one_layer_rnn, nonlinearity=torch.tanh),
|
||||
partial(one_layer_rnn, hidden_fn=rnn_cell(torch.tanh)),
|
||||
)
|
||||
return out, torch.stack(final_hiddens, 0)
|
||||
|
||||
@ -2336,7 +2350,7 @@ def rnn_relu_input(
|
||||
train,
|
||||
bidirectional,
|
||||
batch_first,
|
||||
partial(one_layer_rnn, nonlinearity=torch.relu),
|
||||
partial(one_layer_rnn, hidden_fn=rnn_cell(torch.relu)),
|
||||
)
|
||||
return out, torch.stack(final_hiddens, 0)
|
||||
|
||||
@ -2367,7 +2381,11 @@ def rnn_relu_data(
|
||||
train,
|
||||
bidirectional,
|
||||
False,
|
||||
partial(one_layer_rnn_data, batch_sizes=batch_sizes, nonlinearity=torch.relu),
|
||||
partial(
|
||||
one_layer_rnn_data,
|
||||
batch_sizes=batch_sizes,
|
||||
hidden_fn=rnn_cell_data(torch.relu),
|
||||
),
|
||||
)
|
||||
return out, torch.stack(final_hiddens, 0)
|
||||
|
||||
@ -2398,7 +2416,11 @@ def rnn_tanh_data(
|
||||
train,
|
||||
bidirectional,
|
||||
False,
|
||||
partial(one_layer_rnn_data, batch_sizes=batch_sizes, nonlinearity=torch.tanh),
|
||||
partial(
|
||||
one_layer_rnn_data,
|
||||
batch_sizes=batch_sizes,
|
||||
hidden_fn=rnn_cell_data(torch.tanh),
|
||||
),
|
||||
)
|
||||
return out, torch.stack(final_hiddens, 0)
|
||||
|
||||
@ -2573,6 +2595,84 @@ def lstm_data_impl(
|
||||
return out, torch.stack(final_hiddens[0], 0), torch.stack(final_hiddens[1], 0)
|
||||
|
||||
|
||||
def gru_cell(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
|
||||
chunked_igates = inp.chunk(3, 1)
|
||||
chunked_hgates = F.linear(cur_hidden, hh_weight, hh_bias).chunk(3, 2)
|
||||
reset_gate = (chunked_hgates[0] + chunked_igates[0]).sigmoid()
|
||||
input_gate = (chunked_hgates[1] + chunked_igates[1]).sigmoid()
|
||||
new_gate = (chunked_igates[2] + (chunked_hgates[2] * reset_gate)).tanh()
|
||||
return (cur_hidden - new_gate) * input_gate + new_gate
|
||||
|
||||
|
||||
def gru_cell_data(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
|
||||
chunked_igates = F.linear(inp, ih_weight, ih_bias).chunk(3, 1)
|
||||
chunked_hgates = F.linear(cur_hidden, hh_weight, hh_bias).chunk(3, 1)
|
||||
reset_gate = (chunked_hgates[0] + chunked_igates[0]).sigmoid()
|
||||
input_gate = (chunked_hgates[1] + chunked_igates[1]).sigmoid()
|
||||
new_gate = (chunked_igates[2] + (chunked_hgates[2] * reset_gate)).tanh()
|
||||
return (cur_hidden - new_gate) * input_gate + new_gate
|
||||
|
||||
|
||||
@register_decomposition(aten.gru.data)
|
||||
@aten.gru.data.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||
@aten.gru.data.py_impl(DispatchKey.Autograd)
|
||||
def gru_impl_data(
|
||||
data,
|
||||
batch_sizes,
|
||||
hx,
|
||||
params,
|
||||
has_biases,
|
||||
num_layers,
|
||||
dropout,
|
||||
train,
|
||||
bidirectional,
|
||||
):
|
||||
params = gather_params(params, has_biases, False)
|
||||
out, final_hiddens = _rnn_helper(
|
||||
data,
|
||||
hx.unbind(0),
|
||||
params,
|
||||
has_biases,
|
||||
num_layers,
|
||||
dropout,
|
||||
train,
|
||||
bidirectional,
|
||||
False,
|
||||
partial(one_layer_rnn_data, batch_sizes=batch_sizes, hidden_fn=gru_cell_data),
|
||||
)
|
||||
return out, torch.stack(final_hiddens, 0)
|
||||
|
||||
|
||||
@register_decomposition(aten.gru.input)
|
||||
@aten.gru.input.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||
@aten.gru.input.py_impl(DispatchKey.Autograd)
|
||||
def gru_impl(
|
||||
input,
|
||||
hx,
|
||||
params,
|
||||
has_biases,
|
||||
num_layers,
|
||||
dropout,
|
||||
train,
|
||||
bidirectional,
|
||||
batch_first,
|
||||
):
|
||||
params = gather_params(params, has_biases, False)
|
||||
out, final_hiddens = _rnn_helper(
|
||||
input,
|
||||
hx.unbind(0),
|
||||
params,
|
||||
has_biases,
|
||||
num_layers,
|
||||
dropout,
|
||||
train,
|
||||
bidirectional,
|
||||
batch_first,
|
||||
partial(one_layer_rnn, hidden_fn=gru_cell),
|
||||
)
|
||||
return out, torch.stack(final_hiddens, 0)
|
||||
|
||||
|
||||
@register_decomposition(aten.upsample_bilinear2d.vec)
|
||||
@aten.upsample_bilinear2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||
@aten.upsample_bilinear2d.vec.py_impl(DispatchKey.Autograd)
|
||||
|
@ -19,6 +19,7 @@ expanded_weights_rnn_decomps = {
|
||||
torch.rnn_relu: (decomposition_table[aten.rnn_relu.input], decomposition_table[aten.rnn_relu.data]),
|
||||
torch.rnn_tanh: (decomposition_table[aten.rnn_tanh.input], decomposition_table[aten.rnn_tanh.data]),
|
||||
torch.lstm: (decomposition_table[aten.lstm.input], decomposition_table[aten.lstm.data]),
|
||||
torch.gru: (decomposition_table[aten.gru.input], decomposition_table[aten.gru.data]),
|
||||
}
|
||||
|
||||
# all of the RNN decomps run linear with the batch dimension second, even if batch_first was set
|
||||
|
Reference in New Issue
Block a user