mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[decompositions] GRU decompositon with and without packed sequence (#91466)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91466 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
5a7c1b7894
commit
fe0e28ab87
@ -387,7 +387,7 @@ class TestExpandedWeightFunctional(TestCase):
|
||||
F.group_norm(inp, 2) # 5 is not divisible by 2
|
||||
|
||||
class TestExpandedWeightModule(TestCase):
|
||||
def _do_test(self, module, input, args=None, kwargs=None, batch_first=True):
|
||||
def _do_test(self, module, input, args=None, kwargs=None, batch_first=True, atol=None, rtol=None):
|
||||
args = args or ()
|
||||
kwargs = kwargs or {}
|
||||
|
||||
@ -432,7 +432,7 @@ class TestExpandedWeightModule(TestCase):
|
||||
if not batch_first:
|
||||
expected_grads[-1] = expected_grads[-1].transpose(0, 1)
|
||||
self.assertEqual(actual_res, expected_res)
|
||||
[self.assertEqual(actual, expected) for (actual, expected) in zip(actual_grads, expected_grads)]
|
||||
[self.assertEqual(actual, expected, atol=atol, rtol=rtol) for (actual, expected) in zip(actual_grads, expected_grads)]
|
||||
|
||||
def _do_test_multi_input(self, module, input):
|
||||
class TestModule(nn.Module):
|
||||
@ -475,7 +475,7 @@ class TestExpandedWeightModule(TestCase):
|
||||
expected_grads = tuple(expected_grad for expected_grad in expected_grads if expected_grad is not None)
|
||||
assert [self.assertEqual(actual, 2 * expected) for (actual, expected) in zip(actual_grads, expected_grads)]
|
||||
|
||||
def _do_test_rnn_packed_sequence(self, module, input, args=None, kwargs=None):
|
||||
def _do_test_rnn_packed_sequence(self, module, input, args=None, kwargs=None, atol=None, rtol=None):
|
||||
args = args if args is not None else ()
|
||||
kwargs = kwargs if kwargs is not None else {}
|
||||
|
||||
@ -510,9 +510,9 @@ class TestExpandedWeightModule(TestCase):
|
||||
|
||||
expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)]
|
||||
self.assertEqual(actual_res, expected_res)
|
||||
[self.assertEqual(actual, expected) for (actual, expected) in zip(actual_grads, expected_grads)]
|
||||
[self.assertEqual(actual, expected, atol=atol, rtol=rtol) for (actual, expected) in zip(actual_grads, expected_grads)]
|
||||
|
||||
@modules(filter(lambda m_info: m_info.module_cls in (torch.nn.RNN, torch.nn.LSTM), module_db))
|
||||
@modules(filter(lambda m_info: m_info.module_cls in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU), module_db))
|
||||
def test_module(self, device, dtype, module_info, training):
|
||||
class RNNWrapper(torch.nn.Module):
|
||||
def __init__(self, m_cons, args, kwargs):
|
||||
@ -531,6 +531,7 @@ class TestExpandedWeightModule(TestCase):
|
||||
|
||||
|
||||
module_cls = module_info.module_cls
|
||||
atol, rtol = (1e-4, 1e-5) if module_cls == torch.nn.GRU and dtype == torch.float32 else (None, None)
|
||||
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
|
||||
requires_grad=True, training=training, with_packed_sequence=True)
|
||||
for module_input in module_inputs:
|
||||
@ -562,9 +563,9 @@ class TestExpandedWeightModule(TestCase):
|
||||
args[1] = h
|
||||
|
||||
if isinstance(input, torch.nn.utils.rnn.PackedSequence):
|
||||
self._do_test_rnn_packed_sequence(m, input, args[1:], kwargs)
|
||||
self._do_test_rnn_packed_sequence(m, input, args[1:], kwargs, atol=atol, rtol=rtol)
|
||||
else:
|
||||
self._do_test(m, input, args[1:], kwargs, batch_first=batch_first)
|
||||
self._do_test(m, input, args[1:], kwargs, batch_first=batch_first, atol=atol, rtol=rtol)
|
||||
|
||||
def test_per_sample_api_failing(self):
|
||||
module = nn.Linear(10, 10)
|
||||
|
Reference in New Issue
Block a user