[decompositions] GRU decompositon with and without packed sequence (#91466)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91466
Approved by: https://github.com/zou3519
This commit is contained in:
lezcano
2023-02-08 09:36:21 +00:00
committed by PyTorch MergeBot
parent 5a7c1b7894
commit fe0e28ab87
4 changed files with 120 additions and 18 deletions

View File

@ -387,7 +387,7 @@ class TestExpandedWeightFunctional(TestCase):
F.group_norm(inp, 2) # 5 is not divisible by 2
class TestExpandedWeightModule(TestCase):
def _do_test(self, module, input, args=None, kwargs=None, batch_first=True):
def _do_test(self, module, input, args=None, kwargs=None, batch_first=True, atol=None, rtol=None):
args = args or ()
kwargs = kwargs or {}
@ -432,7 +432,7 @@ class TestExpandedWeightModule(TestCase):
if not batch_first:
expected_grads[-1] = expected_grads[-1].transpose(0, 1)
self.assertEqual(actual_res, expected_res)
[self.assertEqual(actual, expected) for (actual, expected) in zip(actual_grads, expected_grads)]
[self.assertEqual(actual, expected, atol=atol, rtol=rtol) for (actual, expected) in zip(actual_grads, expected_grads)]
def _do_test_multi_input(self, module, input):
class TestModule(nn.Module):
@ -475,7 +475,7 @@ class TestExpandedWeightModule(TestCase):
expected_grads = tuple(expected_grad for expected_grad in expected_grads if expected_grad is not None)
assert [self.assertEqual(actual, 2 * expected) for (actual, expected) in zip(actual_grads, expected_grads)]
def _do_test_rnn_packed_sequence(self, module, input, args=None, kwargs=None):
def _do_test_rnn_packed_sequence(self, module, input, args=None, kwargs=None, atol=None, rtol=None):
args = args if args is not None else ()
kwargs = kwargs if kwargs is not None else {}
@ -510,9 +510,9 @@ class TestExpandedWeightModule(TestCase):
expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)]
self.assertEqual(actual_res, expected_res)
[self.assertEqual(actual, expected) for (actual, expected) in zip(actual_grads, expected_grads)]
[self.assertEqual(actual, expected, atol=atol, rtol=rtol) for (actual, expected) in zip(actual_grads, expected_grads)]
@modules(filter(lambda m_info: m_info.module_cls in (torch.nn.RNN, torch.nn.LSTM), module_db))
@modules(filter(lambda m_info: m_info.module_cls in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU), module_db))
def test_module(self, device, dtype, module_info, training):
class RNNWrapper(torch.nn.Module):
def __init__(self, m_cons, args, kwargs):
@ -531,6 +531,7 @@ class TestExpandedWeightModule(TestCase):
module_cls = module_info.module_cls
atol, rtol = (1e-4, 1e-5) if module_cls == torch.nn.GRU and dtype == torch.float32 else (None, None)
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=True, training=training, with_packed_sequence=True)
for module_input in module_inputs:
@ -562,9 +563,9 @@ class TestExpandedWeightModule(TestCase):
args[1] = h
if isinstance(input, torch.nn.utils.rnn.PackedSequence):
self._do_test_rnn_packed_sequence(m, input, args[1:], kwargs)
self._do_test_rnn_packed_sequence(m, input, args[1:], kwargs, atol=atol, rtol=rtol)
else:
self._do_test(m, input, args[1:], kwargs, batch_first=batch_first)
self._do_test(m, input, args[1:], kwargs, batch_first=batch_first, atol=atol, rtol=rtol)
def test_per_sample_api_failing(self):
module = nn.Linear(10, 10)