Fix missing initialization of FastSpeech2Conformer (#39689)

* fix missing initialization of FastSpeech2Conformer

* switch order and reactivate tests

---------

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
This commit is contained in:
BUI Van Tuan
2025-07-28 10:47:39 +02:00
committed by GitHub
parent a6393e7d28
commit 6a61e16626
2 changed files with 7 additions and 6 deletions

View File

@@ -966,14 +966,18 @@ class FastSpeech2ConformerPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (nn.LayerNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, std=1.0 / math.sqrt(module.weight.size(1)))
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Conv1d):
nn.init.kaiming_normal_(module.weight)
if module.bias is not None:
key = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
nn.init.uniform_(module.bias, a=-key, b=key)
elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_()
if module.padding_idx is not None:

View File

@@ -28,7 +28,6 @@ from transformers.testing_utils import (
Expectations,
require_g2p_en,
require_torch,
require_torch_accelerator,
slow,
torch_device,
)
@@ -123,7 +122,6 @@ class FastSpeech2ConformerModelTester:
return config, inputs_dict
@require_torch_accelerator
@require_torch
class FastSpeech2ConformerModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (FastSpeech2ConformerModel,) if is_torch_available() else ()
@@ -561,7 +559,6 @@ class FastSpeech2ConformerWithHifiGanTester:
return config, inputs_dict
@require_torch_accelerator
@require_torch
class FastSpeech2ConformerWithHifiGanTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (FastSpeech2ConformerWithHifiGan,) if is_torch_available() else ()