mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-20 17:13:56 +08:00
Fix missing initialization of FastSpeech2Conformer
(#39689)
* fix missing initialization of FastSpeech2Conformer * switch order and reactivate tests --------- Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
This commit is contained in:
@ -966,14 +966,18 @@ class FastSpeech2ConformerPreTrainedModel(PreTrainedModel):
|
||||
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, (nn.LayerNorm)):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear):
|
||||
nn.init.normal_(module.weight, std=1.0 / math.sqrt(module.weight.size(1)))
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Conv1d):
|
||||
nn.init.kaiming_normal_(module.weight)
|
||||
if module.bias is not None:
|
||||
key = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
|
||||
nn.init.uniform_(module.bias, a=-key, b=key)
|
||||
elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_()
|
||||
if module.padding_idx is not None:
|
||||
|
@ -28,7 +28,6 @@ from transformers.testing_utils import (
|
||||
Expectations,
|
||||
require_g2p_en,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
@ -123,7 +122,6 @@ class FastSpeech2ConformerModelTester:
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch_accelerator
|
||||
@require_torch
|
||||
class FastSpeech2ConformerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (FastSpeech2ConformerModel,) if is_torch_available() else ()
|
||||
@ -561,7 +559,6 @@ class FastSpeech2ConformerWithHifiGanTester:
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch_accelerator
|
||||
@require_torch
|
||||
class FastSpeech2ConformerWithHifiGanTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (FastSpeech2ConformerWithHifiGan,) if is_torch_available() else ()
|
||||
|
Reference in New Issue
Block a user