Compare commits

...

4 Commits

Author SHA1 Message Date
60a1b0eb8e check 2024-02-06 12:55:36 +01:00
72e2bcc5ab check 2024-02-06 12:20:31 +01:00
4598609450 check 2024-02-06 12:15:08 +01:00
7b1b607991 check 2024-02-06 11:56:21 +01:00
2 changed files with 27 additions and 3 deletions

View File

@@ -610,7 +610,12 @@ class FastSpeech2ConformerConvolutionModule(nn.Module):
hidden_states = nn.functional.glu(hidden_states, dim=1)
# 1D Depthwise Conv
hidden_states = self.depthwise_conv(hidden_states)
try:
hidden_states = self.depthwise_conv(hidden_states)
except:
breakpoint()
print(hidden_states)
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states * torch.sigmoid(hidden_states)

View File

@@ -117,7 +117,7 @@ class FastSpeech2ConformerModelTester:
return config, inputs_dict
@require_torch_accelerator
@require_torch
class FastSpeech2ConformerModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (FastSpeech2ConformerModel,) if is_torch_available() else ()
@@ -275,6 +275,10 @@ class FastSpeech2ConformerModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(decoder_attentions.grad)
def test_attention_outputs(self):
for i in range(100):
self._test_attention_outputs()
def _test_attention_outputs(self):
"""
Custom `test_attention_outputs` since FastSpeech2Conformer does not output cross attentions, has variable
decoder attention shape, and uniquely outputs energy, pitch, and durations.
@@ -303,7 +307,22 @@ class FastSpeech2ConformerModelTest(ModelTesterMixin, unittest.TestCase):
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
inputs = {'input_ids': torch.Tensor([[20, 1, 21, 17, 14, 23, 20],
[14, 21, 13, 16, 12, 5, 2],
[ 5, 13, 7, 11, 16, 13, 12],
[ 1, 3, 12, 15, 14, 17, 23],
[ 1, 21, 7, 6, 3, 13, 10],
[20, 15, 17, 13, 19, 13, 16],
[10, 9, 17, 0, 3, 18, 1],
[21, 14, 5, 20, 8, 7, 21],
[ 9, 11, 0, 20, 3, 19, 23],
[11, 18, 10, 19, 22, 22, 4],
[ 4, 9, 14, 21, 19, 4, 19],
[ 4, 0, 11, 17, 21, 19, 6],
[10, 10, 13, 15, 10, 7, 5]]).to(dtype=torch.int64), 'output_attentions': True, 'output_hidden_states': False}
outputs = model(**inputs)
continue
encoder_attentions = outputs.encoder_attentions
self.assertEqual(len(encoder_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(