Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)
Add Parakeet (#39062)
* first commit
* update to handle masking for bs>1
* Add tests and docs
* update model ids
* update docs and improve style
* update librosa location
* import guard torch too
* ruff code checks fix
* ruff format check
* updated to parakeet names
* update script
* Add tokenizer decoding
* Remove other model dependency
* clean tests
* fix tests
* linting
* fix ruff lint warnings
* move to separate folders
* add parakeet ctc model code
* simplify encoder structure
* update documentation
* add parakeet to toctree
* fix tests
* add parakeet doc
* Address comments
* Update featurizer to compute lens directly
* fix ruff tests
* fix encoding format
* fix minor ctc decoding
* revert modular_model_converter.py changes
* revert check_config_attributes.py changes
* refactor: fastconformer & parakeet_ctc -> parakeet
* modeling update
* test update
* propagate feature extractor updates
* propagate doc changes
* propagate doc changes
* propagate tokenization changes
* propagate conversion changes
* remove fastconformer tests
* remove modular
* update processor
* update processor
* test update
* diverse fixes
* 100% matching greedy batched
* Update conversion script.
* Refactor docs.
* Refactor auto loading.
* Refactor and fix tokenization and processing.
* Update integration test.
* Modeling fixes:
  - ensure correct attention mask shape
  - ensure layer drop returns valid output
  - correct blank token ID when computing CTC loss
* Format and repo consistency.
* Update model doc.
* Fix feature extraction tests.
* Fix (most) tokenizer tests.
* Add pipeline example.
* Fixes
* Use eager_attention_forward from Llama.
* Small tweaks.
* Replace Sequential with ModuleList
* Add check if not all layers copied
* Clean tokenizer.
* Standardize FastSpeech2ConformerConvolutionModule for Parakeet.
* Switch to modular for modeling and processing.
* Add processor tests.
* Fix modeling tests.
* Formatting and docstrings.
* Add `return_attention_mask` like other feature extractors.
* clean up after merging main.
* nits on modeling
* configuration update
* nit
* simplification: use PretrainedTokenizerFast, simplify processor
* add dtype arg to mel_filter_bank
* feature extraction: simplify!
* modeling update
* change to ParakeetTokenizerFast
* correct attention mask handling
* auto update
* proc update
* test update
* feature extraction fixes
* modeling update
* conversion script update
* update tests feature integration
* update tokenization and tests
* processor tests
* revert audio_utils
* config docstring update
* blank_token -> pad_token
* modeling update
* doc update
* fix tests
* fix test
* fix tests
* address review comments
* add comment
* add comment
* explicitly not support flash
* attention straightforward masking
* fix
* tokenizer update: skipping blank tokens by default
* doc update
* fix max_positions_embeddings handling
* nits
* change atol feature extraction integration tests
* doc update + fix loss
* doc update
* nit
* update integration test for A10
* repo id name
* nit

---------

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
Co-authored-by: Eustache Le Bihan <eulebihan@gmail.com>
Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
Co-authored-by: Eric B <ebezzam@gmail.com>
docs/source/en/_toctree.yml:

```diff
@@ -935,6 +935,8 @@
       title: MusicGen
     - local: model_doc/musicgen_melody
       title: MusicGen Melody
+    - local: model_doc/parakeet
+      title: Parakeet
     - local: model_doc/pop2piano
       title: Pop2Piano
     - local: model_doc/seamless_m4t
```
docs/source/en/model_doc/parakeet.md (new file, 220 lines):

<!--Copyright 2025 The NVIDIA NeMo Team and The HuggingFace Inc. team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>

# Parakeet

## Overview

Parakeet models, [introduced by NVIDIA NeMo](https://developer.nvidia.com/blog/pushing-the-boundaries-of-speech-recognition-with-nemo-parakeet-asr-models/), combine a [Fast Conformer](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/asr/models.html#fast-conformer) encoder with a connectionist temporal classification (CTC), recurrent neural network transducer (RNNT), or token-and-duration transducer (TDT) decoder for automatic speech recognition.

**Model Architecture**
- **Fast Conformer Encoder**: A linearly scalable Conformer architecture that processes mel-spectrogram features and reduces sequence length through subsampling. It is a more efficient version of the Conformer encoder found in [FastSpeech2Conformer](./fastspeech2_conformer.md) (see [`ParakeetEncoder`] for the encoder implementation and details).
- [**ParakeetForCTC**](#parakeetforctc): a Fast Conformer encoder + a CTC decoder
  - **CTC Decoder**: Simple but effective decoder consisting of:
    - 1D convolution projection from encoder hidden size to vocabulary size (for optimal NeMo compatibility).
    - CTC loss computation for training.
    - Greedy CTC decoding for inference (sketched below).
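
For intuition, the CTC head emits one vocabulary distribution per encoder frame (with the default `subsampling_factor=8` and a 10 ms feature hop, roughly one every 80 ms of audio), and greedy decoding takes the argmax at each frame, collapses repeats, and removes blanks. A minimal sketch of that last step (illustrative only, not the library's implementation; `logits` and the blank id, which is the model's pad token, are assumed inputs):

```py
import torch

def greedy_ctc_decode(logits: torch.Tensor, blank_id: int) -> list[list[int]]:
    # logits: (batch, time, vocab_size)
    ids = logits.argmax(dim=-1)  # best token per frame
    decoded = []
    for seq in ids:
        seq = torch.unique_consecutive(seq)  # collapse repeated predictions
        decoded.append(seq[seq != blank_id].tolist())  # drop blank tokens
    return decoded
```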

The original implementation can be found in [NVIDIA NeMo](https://github.com/NVIDIA/NeMo).
Model checkpoints can be found under [the NVIDIA organization](https://huggingface.co/nvidia/models?search=parakeet).

This model was contributed by [Nithin Rao Koluguri](https://huggingface.co/nithinraok), [Eustache Le Bihan](https://huggingface.co/eustlb) and [Eric Bezzam](https://huggingface.co/bezzam).

## Usage

### Basic usage

<hfoptions id="usage">
<hfoption id="Pipeline">

```py
from transformers import pipeline

pipe = pipeline("automatic-speech-recognition", model="nvidia/parakeet-ctc-1.1b")
out = pipe("https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3")
print(out)
```

</hfoption>
<hfoption id="AutoModel">

```py
from transformers import AutoModelForCTC, AutoProcessor
from datasets import load_dataset, Audio
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained("nvidia/parakeet-ctc-1.1b")
model = AutoModelForCTC.from_pretrained("nvidia/parakeet-ctc-1.1b", dtype="auto", device_map=device)

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
speech_samples = [el['array'] for el in ds["audio"][:5]]

inputs = processor(speech_samples, sampling_rate=processor.feature_extractor.sampling_rate)
inputs.to(model.device, dtype=model.dtype)
outputs = model.generate(**inputs)
print(processor.batch_decode(outputs))
```

</hfoption>
</hfoptions>

### Making The Model Go Brrr

Parakeet supports full-graph compilation with CUDA graphs! This optimization is most effective when you know the maximum audio length you want to transcribe. The key idea is to use static input shapes to avoid recompilation. For example, if you know your audio will be under 30 seconds, you can use the processor to pad all inputs to 30 seconds, preparing consistent input features and attention masks. See the example below!

```python
from transformers import AutoModelForCTC, AutoProcessor
from datasets import load_dataset, Audio
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained("nvidia/parakeet-ctc-1.1b")
model = AutoModelForCTC.from_pretrained("nvidia/parakeet-ctc-1.1b", dtype="auto", device_map=device)

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
speech_samples = [el['array'] for el in ds["audio"][:5]]

# Compile the generate method with fullgraph and CUDA graphs
model.generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")

# let's define processor kwargs to pad to 30 seconds
processor_kwargs = {
    "padding": "max_length",
    "max_length": 30 * processor.feature_extractor.sampling_rate,
}

# Define a timing context using CUDA events
class TimerContext:
    def __init__(self, name="Execution"):
        self.name = name
        self.start_event = None
        self.end_event = None

    def __enter__(self):
        # Use CUDA events for more accurate GPU timing
        self.start_event = torch.cuda.Event(enable_timing=True)
        self.end_event = torch.cuda.Event(enable_timing=True)
        self.start_event.record()
        return self

    def __exit__(self, *args):
        self.end_event.record()
        torch.cuda.synchronize()
        elapsed_time = self.start_event.elapsed_time(self.end_event) / 1000.0
        print(f"{self.name} time: {elapsed_time:.4f} seconds")


inputs = processor(speech_samples[0], **processor_kwargs)
inputs.to(device, dtype=model.dtype)
print("\n" + "="*50)
print("First generation - compiling...")
# Generate with the compiled model
with TimerContext("First generation"):
    outputs = model.generate(**inputs)
print(processor.batch_decode(outputs))

inputs = processor(speech_samples[1], **processor_kwargs)
inputs.to(device, dtype=model.dtype)
print("\n" + "="*50)
print("Second generation - recording CUDA graphs...")
with TimerContext("Second generation"):
    outputs = model.generate(**inputs)
print(processor.batch_decode(outputs))

inputs = processor(speech_samples[2], **processor_kwargs)
inputs.to(device, dtype=model.dtype)
print("\n" + "="*50)
print("Third generation - fast !!!")
with TimerContext("Third generation"):
    outputs = model.generate(**inputs)
print(processor.batch_decode(outputs))

inputs = processor(speech_samples[3], **processor_kwargs)
inputs.to(device, dtype=model.dtype)
print("\n" + "="*50)
print("Fourth generation - still fast !!!")
with TimerContext("Fourth generation"):
    outputs = model.generate(**inputs)
print(processor.batch_decode(outputs))
```

### Training

```python
from transformers import AutoModelForCTC, AutoProcessor
from datasets import load_dataset, Audio
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained("nvidia/parakeet-ctc-1.1b")
model = AutoModelForCTC.from_pretrained("nvidia/parakeet-ctc-1.1b", dtype="auto", device_map=device)

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
speech_samples = [el['array'] for el in ds["audio"][:5]]
text_samples = [el for el in ds["text"][:5]]

# passing `text` to the processor will prepare the inputs' `labels` key
inputs = processor(audio=speech_samples, text=text_samples, sampling_rate=processor.feature_extractor.sampling_rate)
inputs.to(device, dtype=model.dtype)

outputs = model(**inputs)
outputs.loss.backward()
```

## ParakeetTokenizerFast

[[autodoc]] ParakeetTokenizerFast

## ParakeetFeatureExtractor

[[autodoc]] ParakeetFeatureExtractor
    - __call__

## ParakeetProcessor

[[autodoc]] ParakeetProcessor
    - __call__
    - batch_decode
    - decode

## ParakeetEncoderConfig

[[autodoc]] ParakeetEncoderConfig

## ParakeetCTCConfig

[[autodoc]] ParakeetCTCConfig

## ParakeetEncoder

[[autodoc]] ParakeetEncoder

## ParakeetForCTC

[[autodoc]] ParakeetForCTC
src/transformers/convert_slow_tokenizer.py:

```diff
@@ -1540,6 +1540,54 @@ class HeliumConverter(SpmConverter):
         )
 
 
+class ParakeetConverter(SpmConverter):
+    handle_byte_fallback = True
+
+    def __init__(self, vocab_file=None, *args):
+        self.vocab_file = vocab_file
+
+        requires_backends(self, "protobuf")
+
+        Converter.__init__(self, vocab_file)
+
+        model_pb2 = import_protobuf()
+        m = model_pb2.ModelProto()
+        with open(vocab_file, "rb") as f:
+            m.ParseFromString(f.read())
+        self.proto = m
+
+    def tokenizer(self, proto):
+        vocab_scores = self.vocab(proto)
+
+        _, merges = self.SpmExtractor(self.vocab_file).extract(vocab_scores)
+        bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
+        tokenizer = Tokenizer(
+            BPE(
+                bpe_vocab,
+                merges,
+                unk_token=proto.trainer_spec.unk_piece,
+                fuse_unk=True,
+                byte_fallback=self.handle_byte_fallback,
+                dropout=None,
+            )
+        )
+
+        # Add user defined symbols and control tokens from the sentencepiece model
+        spm_added_tokens = [
+            (id, p.piece, p.type == 3 or p.piece in self.special_tokens)
+            for id, p in enumerate(proto.pieces)
+            if p.type in [3, 4]
+        ]
+        tokenizer.add_tokens(
+            [
+                AddedToken(token, normalized=False, special=special)
+                for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
+            ]
+        )
+
+        return tokenizer
+
+
 # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
 def bytes_to_unicode():
     """
```
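
For orientation, this converter is what the NeMo conversion script later in this commit uses to build the fast tokenizer; a minimal usage sketch (the path is a placeholder for a local SentencePiece model file):

```python
from transformers.convert_slow_tokenizer import ParakeetConverter

tokenizer = ParakeetConverter("tokenizer.model").converted()  # -> tokenizers.Tokenizer (BPE with byte fallback)
```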
src/transformers/models/__init__.py:

```diff
@@ -253,6 +253,7 @@ if TYPE_CHECKING:
     from .owlv2 import *
     from .owlvit import *
     from .paligemma import *
+    from .parakeet import *
     from .patchtsmixer import *
     from .patchtst import *
     from .pegasus import *
```
src/transformers/models/auto/configuration_auto.py:

```diff
@@ -296,6 +296,8 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
         ("owlv2", "Owlv2Config"),
         ("owlvit", "OwlViTConfig"),
         ("paligemma", "PaliGemmaConfig"),
+        ("parakeet_ctc", "ParakeetCTCConfig"),
+        ("parakeet_encoder", "ParakeetEncoderConfig"),
         ("patchtsmixer", "PatchTSMixerConfig"),
         ("patchtst", "PatchTSTConfig"),
         ("pegasus", "PegasusConfig"),
@@ -745,6 +747,9 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
         ("owlv2", "OWLv2"),
         ("owlvit", "OWL-ViT"),
         ("paligemma", "PaliGemma"),
+        ("parakeet", "Parakeet"),
+        ("parakeet_ctc", "Parakeet"),
+        ("parakeet_encoder", "ParakeetEncoder"),
         ("patchtsmixer", "PatchTSMixer"),
         ("patchtst", "PatchTST"),
         ("pegasus", "Pegasus"),
@@ -984,6 +989,8 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict[str, str](
         ("blip_2_qformer", "blip_2"),
         ("fastspeech2_conformer_with_hifigan", "fastspeech2_conformer"),
         ("perception_encoder", "perception_lm"),
+        ("parakeet_encoder", "parakeet"),
+        ("parakeet_ctc", "parakeet"),
     ]
 )
```
src/transformers/models/auto/feature_extraction_auto.py:

```diff
@@ -81,6 +81,8 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
         ("moshi", "EncodecFeatureExtractor"),
         ("nat", "ViTFeatureExtractor"),
         ("owlvit", "OwlViTFeatureExtractor"),
+        ("parakeet_ctc", "ParakeetFeatureExtractor"),
+        ("parakeet_encoder", "ParakeetFeatureExtractor"),
         ("perceiver", "PerceiverFeatureExtractor"),
         ("phi4_multimodal", "Phi4MultimodalFeatureExtractor"),
         ("poolformer", "PoolFormerFeatureExtractor"),
```
src/transformers/models/auto/modeling_auto.py:

```diff
@@ -295,6 +295,8 @@ MODEL_MAPPING_NAMES = OrderedDict(
         ("owlv2", "Owlv2Model"),
         ("owlvit", "OwlViTModel"),
         ("paligemma", "PaliGemmaModel"),
+        ("parakeet_ctc", "ParakeetForCTC"),
+        ("parakeet_encoder", "ParakeetEncoder"),
         ("patchtsmixer", "PatchTSMixerModel"),
         ("patchtst", "PatchTSTModel"),
         ("pegasus", "PegasusModel"),
@@ -1601,6 +1603,7 @@ MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
         ("data2vec-audio", "Data2VecAudioForCTC"),
         ("hubert", "HubertForCTC"),
         ("mctct", "MCTCTForCTC"),
+        ("parakeet_ctc", "ParakeetForCTC"),
         ("sew", "SEWForCTC"),
         ("sew-d", "SEWDForCTC"),
         ("unispeech", "UniSpeechForCTC"),
```
src/transformers/models/auto/tokenization_auto.py:

```diff
@@ -502,6 +502,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
         ("owlv2", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
         ("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
         ("paligemma", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
+        ("parakeet", ("ParakeetCTCTokenizer", None)),
         (
             "pegasus",
             (
```
src/transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py:

```diff
@@ -21,6 +21,7 @@ from typing import Optional, Union
 import torch
 from torch import nn
 
 from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import ModelOutput, auto_docstring, logging
@@ -472,24 +473,37 @@ class FastSpeech2ConformerAttention(nn.Module):
 
 
 class FastSpeech2ConformerConvolutionModule(nn.Module):
-    def __init__(self, config: FastSpeech2ConformerConfig, module_config):
+    def __init__(self, config: FastSpeech2ConformerConfig, module_config=None):
         """
         Args:
             config (FastSpeech2ConformerConfig): Configuration for the model.
             module_config (dict): Configuration for the module (e.g., encoder or decoder).
         """
         super().__init__()
-        # kernel_size should be an odd number for 'SAME' padding
         channels = config.hidden_size
-        kernel_size = module_config["kernel_size"]
+        # kernel_size should be an odd number for 'SAME' padding
+        if module_config is None:
+            # e.g. using `ParakeetEncoderConfig` in src/transformers/models/parakeet/configuration_parakeet.py
+            kernel_size = config.conv_kernel_size
+            self.activation = ACT2FN[getattr(config, "hidden_act", "silu")]
+        else:
+            kernel_size = module_config["kernel_size"]
+            self.activation = ACT2FN[module_config.get("activation", "silu")]
+        self.padding = (kernel_size - 1) // 2
         self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=True)
         self.depthwise_conv = nn.Conv1d(
-            channels, channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2, groups=channels, bias=True
+            channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=True
         )
         self.norm = nn.BatchNorm1d(channels)
         self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=True)
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, attention_mask=None):
         """
         Compute convolution module.
 
         Args:
             hidden_states (`torch.Tensor` of shape `(batch, time, channels)`): Input tensor.
+            attention_mask (`torch.Tensor` of shape `(batch, 1, time)`): Attention mask.
 
         Returns:
             `torch.Tensor`: Output tensor of shape `(batch, time, channels)`.
@@ -503,12 +517,15 @@ class FastSpeech2ConformerConvolutionModule(nn.Module):
         # (batch_size, channel, dim)
         hidden_states = nn.functional.glu(hidden_states, dim=1)
 
+        # Apply padding mask before convolution
+        if attention_mask is not None:
+            all_masked_rows = torch.all(~attention_mask, dim=-1)
+            hidden_states = hidden_states.masked_fill(all_masked_rows, 0.0)
+
         # 1D Depthwise Conv
         hidden_states = self.depthwise_conv(hidden_states)
         hidden_states = self.norm(hidden_states)
 
-        hidden_states = hidden_states * torch.sigmoid(hidden_states)
-
+        hidden_states = self.activation(hidden_states)
         hidden_states = self.pointwise_conv2(hidden_states)
 
         return hidden_states.transpose(1, 2)
```
src/transformers/models/parakeet/__init__.py (new file, 29 lines):

```python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_parakeet import *
    from .feature_extraction_parakeet import *
    from .modeling_parakeet import *
    from .tokenization_parakeet_fast import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
```
src/transformers/models/parakeet/configuration_parakeet.py (new file, 235 lines):

````python
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parakeet model configuration."""

from typing import Union

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class ParakeetEncoderConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ParakeetEncoder`]. It is used to instantiate a
    `ParakeetEncoder` model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimension of the layers and the hidden states.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 4096):
            Dimension of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the encoder and pooler.
        attention_bias (`bool`, *optional*, defaults to `True`):
            Whether to use bias in the attention layers.
        conv_kernel_size (`int`, *optional*, defaults to 9):
            The kernel size of the convolution layers in the Conformer block.
        subsampling_factor (`int`, *optional*, defaults to 8):
            The factor by which the input sequence is subsampled.
        subsampling_conv_channels (`int`, *optional*, defaults to 256):
            The number of channels in the subsampling convolution layers.
        num_mel_bins (`int`, *optional*, defaults to 80):
            Number of mel features.
        subsampling_conv_kernel_size (`int`, *optional*, defaults to 3):
            The kernel size of the subsampling convolution layers.
        subsampling_conv_stride (`int`, *optional*, defaults to 2):
            The stride of the subsampling convolution layers.
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout ratio for all fully connected layers in the embeddings, encoder, and pooler.
        dropout_positions (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the positions in the input sequence.
        layerdrop (`float`, *optional*, defaults to 0.1):
            The LayerDrop probability for the layers in the encoder.
        activation_dropout (`float`, *optional*, defaults to 0.1):
            The dropout ratio for activations inside the fully connected layer.
        attention_dropout (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention layers.
        max_position_embeddings (`int`, *optional*, defaults to 5000):
            The maximum sequence length that this model might ever be used with.
        scale_input (`bool`, *optional*, defaults to `True`):
            Whether to scale the input embeddings.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

    Example:
    ```python
    >>> from transformers import ParakeetEncoder, ParakeetEncoderConfig

    >>> # Initializing a `ParakeetEncoder` configuration
    >>> configuration = ParakeetEncoderConfig()

    >>> # Initializing a model from the configuration
    >>> model = ParakeetEncoder(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```

    This configuration class is based on the ParakeetEncoder architecture from NVIDIA NeMo. You can find more details
    and pre-trained models at [nvidia/parakeet-ctc-1.1b](https://huggingface.co/nvidia/parakeet-ctc-1.1b).
    """

    model_type = "parakeet_encoder"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=8,
        intermediate_size=4096,
        hidden_act="silu",
        attention_bias=True,
        conv_kernel_size=9,
        subsampling_factor=8,
        subsampling_conv_channels=256,
        num_mel_bins=80,
        subsampling_conv_kernel_size=3,
        subsampling_conv_stride=2,
        dropout=0.1,
        dropout_positions=0.0,
        layerdrop=0.1,
        activation_dropout=0.1,
        attention_dropout=0.1,
        max_position_embeddings=5000,
        scale_input=True,
        initializer_range=0.02,
        **kwargs,
    ):
        super().__init__(
            **kwargs,
        )
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_attention_heads  # LlamaAttention compatibility
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.attention_bias = attention_bias

        if (conv_kernel_size - 1) % 2 != 0:
            raise ValueError(f"conv_kernel_size must be odd, got {conv_kernel_size}")
        self.conv_kernel_size = conv_kernel_size

        self.subsampling_conv_kernel_size = subsampling_conv_kernel_size
        self.subsampling_conv_stride = subsampling_conv_stride

        self.subsampling_factor = subsampling_factor
        self.subsampling_conv_channels = subsampling_conv_channels
        self.num_mel_bins = num_mel_bins

        self.dropout = dropout
        self.dropout_positions = dropout_positions
        self.layerdrop = layerdrop
        self.activation_dropout = activation_dropout
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings
        self.scale_input = scale_input
        self.initializer_range = initializer_range


class ParakeetCTCConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`ParakeetForCTC`]. It is used to instantiate a
    Parakeet CTC model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 1025):
            Vocabulary size of the model.
        ctc_loss_reduction (`str`, *optional*, defaults to `"mean"`):
            Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
            instance of [`ParakeetForCTC`].
        ctc_zero_infinity (`bool`, *optional*, defaults to `True`):
            Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
            occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
            of [`ParakeetForCTC`].
        encoder_config (`Union[dict, ParakeetEncoderConfig]`, *optional*):
            The config object or dictionary of the encoder.
        pad_token_id (`int`, *optional*, defaults to 1024):
            Padding token id. Also used as the blank token id.

    Example:
    ```python
    >>> from transformers import ParakeetForCTC, ParakeetCTCConfig

    >>> # Initializing a Parakeet configuration
    >>> configuration = ParakeetCTCConfig()

    >>> # Initializing a model from the configuration
    >>> model = ParakeetForCTC(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```

    This configuration class is based on the Parakeet CTC architecture from NVIDIA NeMo. You can find more details
    and pre-trained models at [nvidia/parakeet-ctc-1.1b](https://huggingface.co/nvidia/parakeet-ctc-1.1b).
    """

    model_type = "parakeet_ctc"
    sub_configs = {"encoder_config": ParakeetEncoderConfig}

    def __init__(
        self,
        vocab_size=1025,
        ctc_loss_reduction="mean",
        ctc_zero_infinity=True,
        encoder_config: Union[dict, ParakeetEncoderConfig] = None,
        pad_token_id=1024,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.ctc_loss_reduction = ctc_loss_reduction
        self.ctc_zero_infinity = ctc_zero_infinity

        if isinstance(encoder_config, dict):
            encoder_config = ParakeetEncoderConfig(**encoder_config)
        elif encoder_config is None:
            encoder_config = ParakeetEncoderConfig()

        self.encoder_config = encoder_config
        self.initializer_range = self.encoder_config.initializer_range

        super().__init__(
            pad_token_id=pad_token_id,
            **kwargs,
        )

    @classmethod
    def from_encoder_config(cls, encoder_config: ParakeetEncoderConfig, **kwargs):
        r"""
        Instantiate a [`ParakeetCTCConfig`] (or a derived class) from a Parakeet encoder model configuration.

        Returns:
            [`ParakeetCTCConfig`]: An instance of a configuration object
        """

        return cls(encoder_config=encoder_config.to_dict(), **kwargs)


__all__ = ["ParakeetCTCConfig", "ParakeetEncoderConfig"]
````
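
As a quick illustration of the `from_encoder_config` helper above (the values are arbitrary):

```python
from transformers import ParakeetCTCConfig, ParakeetEncoderConfig

encoder_config = ParakeetEncoderConfig(hidden_size=512, num_hidden_layers=12)
config = ParakeetCTCConfig.from_encoder_config(encoder_config, vocab_size=1025)
print(config.encoder_config.hidden_size)  # 512
```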
src/transformers/models/parakeet/convert_nemo_to_hf.py (new file, 315 lines):

```python
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import gc
import os
import re
import tarfile

import torch
import yaml
from tokenizers import AddedToken

from transformers import (
    ParakeetCTCConfig,
    ParakeetFeatureExtractor,
    ParakeetForCTC,
    ParakeetProcessor,
    ParakeetTokenizerFast,
)
from transformers.convert_slow_tokenizer import ParakeetConverter
from transformers.utils.hub import cached_file


NEMO_TO_HF_WEIGHT_MAPPING = {
    r"encoder\.pre_encode\.conv\.": r"encoder.subsampling.layers.",
    r"encoder\.pre_encode\.out\.": r"encoder.subsampling.linear.",
    r"encoder\.pos_enc\.": r"encoder.encode_positions.",
    r"encoder\.layers\.(\d+)\.conv\.batch_norm\.": r"encoder.layers.\1.conv.norm.",
    r"decoder\.decoder_layers\.0\.(weight|bias)": r"ctc_head.\1",
    r"linear_([kv])": r"\1_proj",
    r"linear_out": r"o_proj",
    r"linear_q": r"q_proj",
    r"pos_bias_([uv])": r"bias_\1",
    r"linear_pos": r"relative_k_proj",
}


def convert_key(key, mapping):
    for pattern, replacement in mapping.items():
        key = re.sub(pattern, replacement, key)
    return key
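

# Illustrative example (not part of the original script): the mapping above rewrites
# NeMo parameter names into their HF equivalents, e.g.
#   convert_key("encoder.layers.0.conv.batch_norm.weight", NEMO_TO_HF_WEIGHT_MAPPING)
#   -> "encoder.layers.0.conv.norm.weight"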


def extract_nemo_archive(nemo_file_path: str, extract_dir: str) -> dict[str, str]:
    """
    Extract a .nemo file (tar archive) and return paths to its important files.

    Args:
        nemo_file_path: Path to the .nemo file
        extract_dir: Directory to extract to

    Returns:
        Dictionary with paths to model.pt, model_config.yaml, etc.
    """
    print(f"Extracting NeMo archive: {nemo_file_path}")

    with tarfile.open(nemo_file_path, "r", encoding="utf-8") as tar:
        tar.extractall(extract_dir)

    # Log all extracted files for debugging
    all_files = []
    for root, dirs, files in os.walk(extract_dir):
        for file in files:
            file_path = os.path.join(root, file)
            all_files.append(file_path)

    print(f"All extracted files: {[os.path.basename(f) for f in all_files]}")

    # Find important files with more robust detection
    model_files = {}
    for root, dirs, files in os.walk(extract_dir):
        for file in files:
            file_path = os.path.join(root, file)
            file_lower = file.lower()

            # Look for model weights with various common names
            if (
                file.endswith((".pt", ".pth", ".ckpt", ".bin"))
                or ("model" in file_lower and ("weight" in file_lower or "state" in file_lower))
                or file_lower in ("model.pt", "pytorch_model.bin", "model_weights.ckpt")
            ):
                model_files["model_weights"] = file_path
                print(f"Found model weights: {file}")

            # Look for config files
            elif (
                file == "model_config.yaml"
                or file == "config.yaml"
                or (file.endswith(".yaml") and "config" in file_lower)
            ):
                if "model_config" not in model_files:  # Prefer model_config.yaml
                    model_files["model_config"] = file_path
                    print(f"Found config file: {file}")
                if file == "model_config.yaml":
                    model_files["model_config"] = file_path  # Override with preferred name

            # Look for vocabulary files
            elif (
                file.endswith(".vocab")
                or file.endswith(".model")
                or file.endswith(".txt")
                or ("tokenizer" in file_lower and (file.endswith(".vocab") or file.endswith(".model")))
            ):
                # Prefer .model files over others
                if "tokenizer_model_file" not in model_files or file.endswith(".model"):
                    model_files["tokenizer_model_file"] = file_path
                    print(f"Found tokenizer model file: {file}")
                else:
                    print(f"Found additional vocabulary file (using existing): {file}")

    print(f"Found model files: {list(model_files.keys())}")

    # Validate that we found the required files
    if "model_weights" not in model_files:
        raise FileNotFoundError(
            f"Could not find model weights file in {nemo_file_path}. "
            f"Expected files with extensions: .pt, .pth, .ckpt, .bin. "
            f"Found files: {[os.path.basename(f) for f in all_files]}"
        )

    if "model_config" not in model_files:
        raise FileNotFoundError(
            f"Could not find model config file in {nemo_file_path}. "
            f"Expected: model_config.yaml or config.yaml. "
            f"Found files: {[os.path.basename(f) for f in all_files]}"
        )

    return model_files


def write_processor(nemo_config: dict, model_files, output_dir, push_to_repo_id=None):
    tokenizer_converted = ParakeetConverter(model_files["tokenizer_model_file"]).converted()
    tokenizer_converted_fast = ParakeetTokenizerFast(
        tokenizer_object=tokenizer_converted,
        clean_up_tokenization_spaces=False,
    )
    tokenizer_converted_fast.add_tokens(
        [AddedToken("<unk>", normalized=False, special=True), AddedToken("<pad>", normalized=False, special=True)]
    )
    tokenizer_converted_fast.add_special_tokens(
        {
            "pad_token": AddedToken("<pad>", normalized=False, special=True),
            "unk_token": AddedToken("<unk>", normalized=False, special=True),
        }
    )

    feature_extractor_keys_to_ignore = ["_target_", "pad_to", "frame_splicing", "dither", "normalize", "window", "log"]
    feature_extractor_config_keys_mapping = {
        "sample_rate": "sampling_rate",
        "window_size": "win_length",
        "window_stride": "hop_length",
        "window": "window",
        "n_fft": "n_fft",
        "log": "log",
        "features": "feature_size",
        "dither": "dither",
        "pad_to": "pad_to",
        "pad_value": "padding_value",
        "frame_splicing": "frame_splicing",
        "preemphasis": "preemphasis",
        "hop_length": "hop_length",
    }
    converted_feature_extractor_config = {}

    for key, value in nemo_config["preprocessor"].items():
        if key in feature_extractor_keys_to_ignore:
            continue
        if key in feature_extractor_config_keys_mapping:
            if key in ["window_size", "window_stride"]:
                value = int(value * nemo_config["preprocessor"]["sample_rate"])
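                # (illustrative note) NeMo stores window_size/window_stride in seconds, so e.g.
                # window_size=0.025 at sample_rate=16000 becomes win_length=400 samples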
                converted_feature_extractor_config[feature_extractor_config_keys_mapping[key]] = value
        else:
            raise ValueError(f"Key {key} not found in feature_extractor_config_keys_mapping")

    feature_extractor = ParakeetFeatureExtractor(**converted_feature_extractor_config)

    processor = ParakeetProcessor(
        feature_extractor=feature_extractor,
        tokenizer=tokenizer_converted_fast,
    )
    processor.save_pretrained(output_dir)

    if push_to_repo_id:
        processor.push_to_hub(push_to_repo_id)


def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_id=None):
    encoder_keys_to_ignore = [
        "att_context_size",
        "causal_downsampling",
        "stochastic_depth_start_layer",
        "feat_out",
        "stochastic_depth_drop_prob",
        "_target_",
        "ff_expansion_factor",
        "untie_biases",
        "att_context_style",
        "self_attention_model",
        "conv_norm_type",
        "subsampling",
        "stochastic_depth_mode",
        "conv_context_size",
        "dropout_pre_encoder",
    ]
    encoder_config_keys_mapping = {
        "d_model": "hidden_size",
        "n_heads": "num_attention_heads",
        "n_layers": "num_hidden_layers",
        "feat_in": "num_mel_bins",
        "conv_kernel_size": "conv_kernel_size",
        "subsampling_factor": "subsampling_factor",
        "subsampling_conv_channels": "subsampling_conv_channels",
        "pos_emb_max_len": "max_position_embeddings",
        "dropout": "dropout",
        "dropout_emb": "dropout_positions",
        "dropout_att": "attention_dropout",
        "xscaling": "scale_input",
    }
    converted_encoder_config = {}

    for key, value in nemo_config["encoder"].items():
        if key in encoder_keys_to_ignore:
            continue
        if key in encoder_config_keys_mapping:
            converted_encoder_config[encoder_config_keys_mapping[key]] = value
        else:
            raise ValueError(f"Key {key} not found in encoder_config_keys_mapping")

    state_dict = torch.load(model_files["model_weights"], map_location="cpu", weights_only=True)
    converted_state_dict = {}
    for key, value in state_dict.items():
        # Skip preprocessing weights (featurizer components)
        if key.endswith("featurizer.window") or key.endswith("featurizer.fb"):
            print(f"Skipping preprocessing weight: {key}")
            continue
        converted_key = convert_key(key, NEMO_TO_HF_WEIGHT_MAPPING)
        converted_state_dict[converted_key] = value

    if model_type == "ctc":
        model_config = ParakeetCTCConfig(
            encoder_config=converted_encoder_config,
        )
        print("Loading the checkpoint in a Parakeet CTC model.")
        with torch.device("meta"):
            model = ParakeetForCTC(model_config)
        model.load_state_dict(converted_state_dict, strict=True, assign=True)
        print("Checkpoint loaded successfully.")
        del model.config._name_or_path

        print("Saving the model.")
        model.save_pretrained(output_dir)

        if push_to_repo_id:
            model.push_to_hub(push_to_repo_id)

        del converted_state_dict, model

        # Safety check: reload the converted model
        gc.collect()
        print("Reloading the model to check if it's saved correctly.")
        ParakeetForCTC.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto")
        print("Model reloaded successfully.")

    else:
        raise ValueError(f"Model type {model_type} not supported.")


def main(
    hf_repo_id,
    output_dir,
    model_type,
    push_to_repo_id=None,
):
    nemo_filename = f"{hf_repo_id.split('/')[-1]}.nemo"
    filepath = cached_file(hf_repo_id, nemo_filename)

    model_files = extract_nemo_archive(filepath, os.path.dirname(filepath))
    with open(model_files["model_config"], "r") as f:
        nemo_config = yaml.load(f, Loader=yaml.FullLoader)

    write_processor(nemo_config, model_files, output_dir, push_to_repo_id)
    write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_id)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--hf_repo_id", required=True, help="Model repo on huggingface.co")
    parser.add_argument("--model_type", required=True, choices=["ctc"], help="Model type (currently only `ctc`)")
    parser.add_argument("--output_dir", required=True, help="Output directory for the HuggingFace model")
    parser.add_argument("--push_to_repo_id", help="Repository ID to push the model to on the Hub")
    args = parser.parse_args()
    main(
        args.hf_repo_id,
        args.output_dir,
        args.model_type,
        args.push_to_repo_id,
    )
```
287
src/transformers/models/parakeet/feature_extraction_parakeet.py
Normal file
287
src/transformers/models/parakeet/feature_extraction_parakeet.py
Normal file
@ -0,0 +1,287 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
|
||||
from ...feature_extraction_utils import BatchFeature
|
||||
from ...utils import TensorType, is_librosa_available, logging
|
||||
from ...utils.import_utils import requires
|
||||
|
||||
|
||||
if is_librosa_available():
|
||||
import librosa
|
||||
|
||||
|
||||
EPSILON = 1e-5
|
||||
LOG_ZERO_GUARD_VALUE = 2**-24
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@requires(backends=("torch", "librosa"))
|
||||
class ParakeetFeatureExtractor(SequenceFeatureExtractor):
|
||||
r"""
|
||||
Constructs a Parakeet feature extractor.
|
||||
|
||||
This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
|
||||
most of the main methods. Users should refer to this superclass for more information regarding those methods.
|
||||
|
||||
This class extracts mel-filter bank features from raw speech using a custom numpy implementation of the `Short Time
|
||||
Fourier Transform` which should match pytorch's `torch.stft` equivalent.
|
||||
|
||||
Args:
|
||||
feature_size (`int`, *optional*, defaults to 80):
|
||||
The feature dimension of the extracted features.
|
||||
sampling_rate (`int`, *optional*, defaults to 16000):
|
||||
The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
|
||||
hop_length (`int`, *optional*, defaults to 160):
|
||||
Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients.
|
||||
n_fft (`int`, *optional*, defaults to 512):
|
||||
Size of the Fourier transform.
|
||||
win_length (`int`, *optional*, defaults to 400):
|
||||
The window length for the STFT computation.
|
||||
preemphasis (`float`, *optional*, defaults to 0.97):
|
||||
A preemphasis filter coefficient. 0.0 means no preemphasis filter.
|
||||
padding_value (`float`, *optional*, defaults to 0.0):
|
||||
Padding value used to pad the audio. Should correspond to silences.
|
||||
"""
|
||||
|
||||
model_input_names = ["input_features", "attention_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_size=80,
|
||||
sampling_rate=16000,
|
||||
hop_length=160,
|
||||
n_fft=512,
|
||||
win_length=400,
|
||||
preemphasis=0.97,
|
||||
padding_value=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
|
||||
|
||||
self.hop_length = hop_length
|
||||
self.n_fft = n_fft
|
||||
self.win_length = win_length
|
||||
self.preemphasis = preemphasis
|
||||
|
||||
# TODO: @eustlb, for now we use librosa to compute the mel filters
|
||||
# indeed mel_filter_bank uses np.float64 (while librosa uses np.float32), giving numerical differences
|
||||
# self.mel_filters = mel_filter_bank(
|
||||
# num_frequency_bins=n_fft // 2 + 1,
|
||||
# num_mel_filters=feature_size,
|
||||
# min_frequency=0.0,
|
||||
# max_frequency=sampling_rate / 2,
|
||||
# sampling_rate=sampling_rate,
|
||||
# norm="slaney",
|
||||
# mel_scale="slaney",
|
||||
# )
|
||||
mel_filters = librosa.filters.mel(
|
||||
sr=sampling_rate, n_fft=n_fft, n_mels=feature_size, fmin=0.0, fmax=sampling_rate / 2, norm="slaney"
|
||||
)
|
||||
self.mel_filters = torch.from_numpy(mel_filters).to(torch.float32)
|
||||
|
||||
def _torch_extract_fbank_features(self, waveform, device="cpu"):
|
||||
# spectrogram
|
||||
window = torch.hann_window(self.win_length, periodic=False, device=device)
|
||||
stft = torch.stft(
|
||||
waveform,
|
||||
self.n_fft,
|
||||
hop_length=self.hop_length,
|
||||
win_length=self.win_length,
|
||||
window=window,
|
||||
return_complex=True,
|
||||
pad_mode="constant",
|
||||
)
|
||||
# Let's math original implementation
|
||||
# magnitudes = torch.abs(stft) ** 2
|
||||
magnitudes = torch.view_as_real(stft)
|
||||
magnitudes = torch.sqrt(magnitudes.pow(2).sum(-1))
|
||||
magnitudes = magnitudes.pow(2)
|
||||
|
||||
# log mel spectrogram
|
||||
mel_filters = self.mel_filters.to(device)
|
||||
mel_spec = mel_filters @ magnitudes
|
||||
mel_spec = torch.log(mel_spec + LOG_ZERO_GUARD_VALUE)
|
||||
|
||||
# (batch_size, num_mel_filters, num_frames) -> (batch_size, num_frames, num_mel_filters)
|
||||
mel_spec = mel_spec.permute(0, 2, 1)
|
||||
|
||||
return mel_spec
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
raw_speech: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
|
||||
truncation: bool = False,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
padding: Optional[str] = "longest",
|
||||
max_length: Optional[int] = None,
|
||||
sampling_rate: Optional[int] = None,
|
||||
do_normalize: Optional[bool] = None,
|
||||
device: Optional[str] = "cpu",
|
||||
return_token_timestamps: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to featurize and prepare for the model one or several sequence(s). Implementation uses PyTorch for
|
||||
the STFT computation if available, otherwise a slower NumPy based one.
|
||||
|
||||
Args:
|
||||
raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
|
||||
The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
|
||||
values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
|
||||
stereo, i.e. single float per timestep.
|
||||
truncation (`bool`, *optional*, default to `True`):
|
||||
Activates truncation to cut input sequences longer than *max_length* to *max_length*.
|
||||
pad_to_multiple_of (`int`, *optional*, defaults to None):
|
||||
If set will pad the sequence to a multiple of the provided value.
|
||||
|
||||
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
|
||||
`>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
|
||||
return_attention_mask (`bool`, *optional*):
|
||||
Whether to return the attention mask. If left to the default, will return the attention mask according
|
||||
to the specific feature_extractor's default.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
|
||||
<Tip>
|
||||
|
||||
For Parakeet models, `attention_mask` should always be passed for batched inference, to avoid subtle
|
||||
bugs.
|
||||
|
||||
</Tip>
|
||||
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors instead of list of python integers. Acceptable values are:
|
||||
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return Numpy `np.ndarray` objects.
|
||||
sampling_rate (`int`, *optional*):
|
||||
The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
|
||||
`sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition
|
||||
pipeline.
|
||||
padding_value (`float`, *optional*, defaults to 0.0):
|
||||
The value that is used to fill the padding values / vectors.
|
||||
do_normalize (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
|
||||
improve the performance of the model.
|
||||
device (`str`, *optional*, defaults to `'cpu'`):
|
||||
Specifies the device for computation of the log-mel spectrogram of audio signals in the
|
||||
`_torch_extract_fbank_features` method. (e.g., "cpu", "cuda")
|
||||
return_token_timestamps (`bool`, *optional*, defaults to `None`):
|
||||
Deprecated. Use `return_attention_mask` instead from which the number of frames can be inferred.
|
||||
|
||||
Whether or not to return the number of frames of the input raw_speech.
|
||||
These num_frames can be used by the model to compute word level timestamps.
|
||||
"""
|
||||
        if sampling_rate is not None:
            if sampling_rate != self.sampling_rate:
                raise ValueError(
                    f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
                    f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
                    f" was sampled with {self.sampling_rate} and not {sampling_rate}."
                )
        else:
            logger.warning(
                f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
                "Failing to do so can result in silent errors that might be hard to debug."
            )

        # Convert to torch tensor
        if isinstance(raw_speech, np.ndarray):
            raw_speech = torch.tensor(raw_speech)
        elif isinstance(raw_speech, (list, tuple)) and isinstance(raw_speech[0], np.ndarray):
            raw_speech = [torch.tensor(speech) for speech in raw_speech]

        is_batched_torch = isinstance(raw_speech, torch.Tensor) and len(raw_speech.shape) > 1
        if is_batched_torch and len(raw_speech.shape) > 2:
            logger.warning(
                f"Only mono-channel audio is supported for input to {self.__class__.__name__}. "
                "We will take the mean of the channels to convert to mono."
            )
            raw_speech = raw_speech.mean(-1)

        is_batched_sequence = isinstance(raw_speech, (list, tuple))
        if is_batched_sequence:
            # reassigning the loop variable would not update the sequence, so write back by index
            raw_speech = list(raw_speech)
            for idx, speech in enumerate(raw_speech):
                if len(speech.shape) > 1:
                    logger.warning(
                        f"Only mono-channel audio is supported for input to {self.__class__.__name__}. "
                        "We will take the mean of the channels to convert to mono."
                    )
                    raw_speech[idx] = speech.mean(-1)

        if is_batched_torch or is_batched_sequence:
            raw_speech = [speech[:, None].to(torch.float32) for speech in raw_speech]
        else:
            raw_speech = [raw_speech[:, None].to(torch.float32)]

        audio_lengths = [len(speech) for speech in raw_speech]
        batched_speech = BatchFeature({"input_features": raw_speech, "audio_lengths": audio_lengths})

        padded_inputs = self.pad(
            batched_speech,
            padding=padding,
            max_length=max_length,
            truncation=truncation,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors="pt",
        )
        input_features = padded_inputs.input_features.squeeze(-1)

        # preemphasis
        if self.preemphasis is not None:
            timemask = torch.arange(input_features.shape[1], device=input_features.device).unsqueeze(
                0
            ) < padded_inputs.audio_lengths.unsqueeze(1)
            input_features = torch.cat(
                [input_features[:, :1], input_features[:, 1:] - self.preemphasis * input_features[:, :-1]], dim=1
            )
            input_features = input_features.masked_fill(~timemask, 0.0)

        input_features = self._torch_extract_fbank_features(input_features, device)
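        # Frame count for each unpadded waveform: with center padding the signal gains `n_fft // 2` samples on
        # each side and successive frames advance by `hop_length`. For example, with the typical 16 kHz settings
        # n_fft=512 and hop_length=160, a 16000-sample clip yields (16000 + 512 - 512) // 160 = 100 frames.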
        features_lengths = torch.floor_divide(
            padded_inputs.audio_lengths + self.n_fft // 2 * 2 - self.n_fft, self.hop_length
        )
        attention_mask = torch.arange(input_features.shape[1], device=device)[None, :] < features_lengths[:, None]

        # normalize mel features, ignoring padding
        mask = attention_mask.unsqueeze(-1)
        input_features_masked = input_features * mask
        mean = input_features_masked.sum(dim=1) / features_lengths.unsqueeze(-1)
        mean = mean.unsqueeze(1)
        variance = ((input_features_masked - mean) ** 2 * mask).sum(dim=1) / (features_lengths - 1).unsqueeze(-1)
        std = torch.sqrt(variance).unsqueeze(1)
        input_features = (input_features - mean) / (std + EPSILON)
        input_features *= mask

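        # `input_features` is now (batch, frames, num_mel_bins): zero mean / unit variance per mel bin over the
        # valid frames of each example, with padded frames explicitly zeroed.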
        return BatchFeature(
            data={
                "input_features": input_features,
                "attention_mask": attention_mask,
            },
            tensor_type=return_tensors,
        )


__all__ = ["ParakeetFeatureExtractor"]
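A minimal usage sketch for the feature extractor above (not part of the diff; the checkpoint id matches the doc examples in this PR, and the random arrays stand in for real 16 kHz audio):

```python
import numpy as np
from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("nvidia/parakeet-ctc-1.1b")
# Two mono clips of different lengths; padding plus the returned attention_mask handles the batching.
speech = [np.random.randn(16000).astype(np.float32), np.random.randn(24000).astype(np.float32)]
inputs = feature_extractor(speech, sampling_rate=16000, return_tensors="pt")
print(inputs.input_features.shape, inputs.attention_mask.shape)
```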
src/transformers/models/parakeet/modeling_parakeet.py (new file)
@@ -0,0 +1,744 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/parakeet/modular_parakeet.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_parakeet.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass
from typing import Callable, Optional, Union

import torch
from torch import nn

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, CausalLMOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.deprecation import deprecate_kwarg
from ...utils.generic import check_model_inputs
from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig


class ParakeetEncoderRelPositionalEncoding(nn.Module):
    """Relative positional encoding for Parakeet."""

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: ParakeetEncoderConfig, device=None):
        super().__init__()
        self.max_position_embeddings = config.max_position_embeddings
        base = 10000.0
        inv_freq = 1.0 / (
            base
            ** (
                torch.arange(0, config.hidden_size, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
                / config.hidden_size
            )
        )

        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, hidden_states: torch.Tensor):
        seq_length = hidden_states.shape[1]
        if seq_length > self.max_position_embeddings:
            raise ValueError(
                f"Sequence length {seq_length} has to be less than or equal to "
                f"config.max_position_embeddings {self.max_position_embeddings}."
            )

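        # Relative positions run from (seq_length - 1) down to -(seq_length - 1), i.e. 2 * seq_length - 1
        # entries, so every possible query/key offset gets its own embedding.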
        position_ids = torch.arange(seq_length - 1, -seq_length, -1, device=hidden_states.device)
        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(hidden_states.shape[0], -1, 1).to(hidden_states.device)
        )
        position_ids_expanded = position_ids[None, None, :].float()

        device_type = (
            hidden_states.device.type
            if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
            else "cpu"
        )
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            sin = freqs.sin()
            cos = freqs.cos()
            # interleave sin and cos
            pos_embed = torch.stack([sin, cos], dim=-1)
            pos_embed = pos_embed.reshape(*pos_embed.shape[:-2], -1)

        return pos_embed.to(dtype=hidden_states.dtype)


class ParakeetEncoderFeedForward(nn.Module):
    def __init__(self, config: ParakeetEncoderConfig):
        super().__init__()
        self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.attention_bias)
        self.activation = ACT2FN[config.hidden_act]
        self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.attention_bias)
        self.activation_dropout = config.activation_dropout

    def forward(self, hidden_states):
        hidden_states = self.activation(self.linear1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.linear2(hidden_states)
        return hidden_states


class ParakeetEncoderConvolutionModule(nn.Module):
    def __init__(self, config: ParakeetEncoderConfig, module_config=None):
        """
        Args:
            config (ParakeetEncoderConfig): Configuration for the model.
            module_config (dict): Configuration for the module (e.g., encoder or decoder).
        """
        super().__init__()
        channels = config.hidden_size
        # kernel_size should be an odd number for 'SAME' padding
        if module_config is None:
            # e.g. using `ParakeetEncoderEncoderConfig` in src/transformers/models/parakeet_encoder/configuration_parakeet_encoder.py
            kernel_size = config.conv_kernel_size
            self.activation = ACT2FN[getattr(config, "hidden_act", "silu")]
        else:
            kernel_size = module_config["kernel_size"]
            self.activation = ACT2FN[module_config.get("activation", "silu")]
        self.padding = (kernel_size - 1) // 2
        self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=True)
        self.depthwise_conv = nn.Conv1d(
            channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=True
        )
        self.norm = nn.BatchNorm1d(channels)
        self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, hidden_states, attention_mask=None):
        """
        Compute convolution module.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch, time, channels)`): Input tensor.
            attention_mask (`torch.Tensor` of shape `(batch, 1, time)`): Attention mask.

        Returns:
            `torch.Tensor`: Output tensor of shape `(batch, time, channels)`.
        """
        # exchange the temporal dimension and the feature dimension
        hidden_states = hidden_states.transpose(1, 2)

        # GLU mechanism, (batch_size, 2*channel, dim)
        hidden_states = self.pointwise_conv1(hidden_states)
        # (batch_size, channel, dim)
        hidden_states = nn.functional.glu(hidden_states, dim=1)

        # Apply padding mask before convolution
        if attention_mask is not None:
            all_masked_rows = torch.all(~attention_mask, dim=-1)
            hidden_states = hidden_states.masked_fill(all_masked_rows, 0.0)

        # 1D Depthwise Conv
        hidden_states = self.depthwise_conv(hidden_states)
        hidden_states = self.norm(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.pointwise_conv2(hidden_states)

        return hidden_states.transpose(1, 2)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


class ParakeetEncoderAttention(nn.Module):
    """Multi-head attention with relative positional encoding. See section 3.3 of https://huggingface.co/papers/1901.02860."""

    def __init__(self, config: ParakeetEncoderConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = False

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        # W_{k,R} projection
        self.relative_k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
        # global content bias
        self.bias_u = nn.Parameter(torch.zeros(config.num_attention_heads, self.head_dim))
        # global positional bias
        self.bias_v = nn.Parameter(torch.zeros(config.num_attention_heads, self.head_dim))

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_shape = hidden_states.shape[:-1]
        batch_size, seq_length = input_shape
        hidden_shape = (batch_size, seq_length, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

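        # Transformer-XL style score decomposition (section 3.3 of the paper referenced above):
        #   (a) content-content q·k and (c) global content bias u·k are computed together as (q + u)·k,
        #   (b) content-position q·r and (d) global position bias v·r are computed together as (q + v)·r,
        # where r are the relative position embeddings projected by W_{k,R}.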
        query_states_with_bias_u = query_states + self.bias_u.view(
            1, self.config.num_attention_heads, 1, self.head_dim
        )
        query_states_with_bias_v = query_states + self.bias_v.view(
            1, self.config.num_attention_heads, 1, self.head_dim
        )

        relative_key_states = self.relative_k_proj(position_embeddings)
        relative_key_states = relative_key_states.view(batch_size, -1, self.config.num_attention_heads, self.head_dim)

        # terms (b) and (d)
        matrix_bd = query_states_with_bias_v @ relative_key_states.permute(0, 2, 3, 1)
        matrix_bd = self._rel_shift(matrix_bd)
        matrix_bd = matrix_bd[..., :seq_length]
        matrix_bd = matrix_bd * self.scaling

        if attention_mask is not None:
            # here the original codebase uses -10000.0 rather than float("-inf") and then a manual masked fill with 0.0s
            # see: https://github.com/NVIDIA-NeMo/NeMo/blob/8cfedd7203462cb251a914e700e5605444277561/nemo/collections/asr/parts/submodules/multi_head_attention.py#L320-L340
            # we rather went for a straightforward approach with float("-inf")
            matrix_bd = matrix_bd.masked_fill_(attention_mask.logical_not(), float("-inf"))

        # will compute matrix_ac - terms (a) and (c) - and add matrix_bd
        attn_output, attn_weights = attention_interface(
            self,
            query=query_states_with_bias_u,
            key=key_states,
            value=value_states,
            attention_mask=matrix_bd,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

    def _rel_shift(self, attention_scores):
        """Relative position shift for Shaw et al. style attention. See appendix B of https://huggingface.co/papers/1901.02860."""
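        # Pad one column on the left, view as (..., position_length + 1, query_length) and drop the first
        # row: this shifts row i of the scores left by i, aligning relative offsets with absolute key indices.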
        batch_size, num_heads, query_length, position_length = attention_scores.shape
        attention_scores = nn.functional.pad(attention_scores, pad=(1, 0))
        attention_scores = attention_scores.view(batch_size, num_heads, -1, query_length)
        attention_scores = attention_scores[:, :, 1:].view(batch_size, num_heads, query_length, position_length)
        return attention_scores


class ParakeetEncoderSubsamplingConv2D(nn.Module):
    def __init__(self, config: ParakeetEncoderConfig):
        super().__init__()

        self.kernel_size = config.subsampling_conv_kernel_size
        self.stride = config.subsampling_conv_stride
        self.channels = config.subsampling_conv_channels
        self.padding = (self.kernel_size - 1) // 2
        self.num_layers = int(math.log2(config.subsampling_factor))

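        # e.g. subsampling_factor=8 with stride 2 gives num_layers=3 stride-2 stages, so both the time axis
        # and the mel-bin axis are reduced by 2**3 = 8 before the final linear projection.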
        # define layers
        self.layers = nn.ModuleList()
        self.layers.append(
            nn.Conv2d(1, self.channels, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding)
        )
        self.layers.append(nn.ReLU())
        for _ in range(self.num_layers - 1):
            # depthwise conv
            self.layers.append(
                nn.Conv2d(
                    self.channels,
                    self.channels,
                    kernel_size=self.kernel_size,
                    stride=self.stride,
                    padding=self.padding,
                    groups=self.channels,
                )
            )
            # pointwise conv
            self.layers.append(nn.Conv2d(self.channels, self.channels, kernel_size=1))
            # activation
            self.layers.append(nn.ReLU())

        out_length = config.num_mel_bins // (self.stride**self.num_layers)
        self.linear = nn.Linear(config.subsampling_conv_channels * out_length, config.hidden_size, bias=True)

    def _get_output_length(self, input_lengths: torch.Tensor, conv_layer: nn.Conv2d):
        if hasattr(conv_layer, "stride") and conv_layer.stride != (1, 1):
            padding = conv_layer.padding
            kernel_size = conv_layer.kernel_size[0]
            stride = conv_layer.stride[0]

            output_lengths = (input_lengths + padding[0] + padding[1] - kernel_size) // stride + 1
            return output_lengths

        return input_lengths

    def forward(self, input_features: torch.Tensor, attention_mask: torch.Tensor = None):
        hidden_states = input_features.unsqueeze(1)
        current_lengths = attention_mask.sum(-1) if attention_mask is not None else None

        for layer in self.layers:
            hidden_states = layer(hidden_states)

            # mask the hidden states
            if isinstance(layer, nn.Conv2d) and attention_mask is not None:
                current_lengths = self._get_output_length(current_lengths, layer)
                current_seq_length = hidden_states.shape[2]
                channel_mask = (
                    torch.arange(current_seq_length, device=attention_mask.device) < current_lengths[:, None]
                )
                hidden_states *= channel_mask[:, None, :, None]

        hidden_states = hidden_states.transpose(1, 2).reshape(hidden_states.shape[0], hidden_states.shape[2], -1)
        hidden_states = self.linear(hidden_states)

        return hidden_states


class ParakeetEncoderBlock(GradientCheckpointingLayer):
    def __init__(self, config: ParakeetEncoderConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.gradient_checkpointing = False

        self.feed_forward1 = ParakeetEncoderFeedForward(config)
        self.self_attn = ParakeetEncoderAttention(config, layer_idx)
        self.conv = ParakeetEncoderConvolutionModule(config)
        self.feed_forward2 = ParakeetEncoderFeedForward(config)

        self.norm_feed_forward1 = nn.LayerNorm(config.hidden_size)
        self.norm_self_att = nn.LayerNorm(config.hidden_size)
        self.norm_conv = nn.LayerNorm(config.hidden_size)
        self.norm_feed_forward2 = nn.LayerNorm(config.hidden_size)
        self.norm_out = nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.feed_forward1(self.norm_feed_forward1(hidden_states))
        hidden_states = residual + 0.5 * hidden_states  # the conformer architecture uses a factor of 0.5

        normalized_hidden_states = self.norm_self_att(hidden_states)
        attn_output, _ = self.self_attn(
            hidden_states=normalized_hidden_states,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output

        conv_output = self.conv(self.norm_conv(hidden_states), attention_mask=attention_mask)
        hidden_states = hidden_states + conv_output

        ff2_output = self.feed_forward2(self.norm_feed_forward2(hidden_states))
        hidden_states = hidden_states + 0.5 * ff2_output  # the conformer architecture uses a factor of 0.5

        hidden_states = self.norm_out(hidden_states)

        return hidden_states


@auto_docstring
class ParakeetPreTrainedModel(PreTrainedModel):
    config: ParakeetCTCConfig
    base_model_prefix = "model"
    main_input_name = "input_features"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ParakeetEncoderBlock"]
    _supports_flat_attention_mask = True
    _supports_sdpa = True
    _supports_flex_attn = True

    # TODO: @eustlb, add support when flash attention supports custom attention bias
    _supports_flash_attn = False

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": ParakeetEncoderBlock,
        "attentions": ParakeetEncoderAttention,
    }

    def _init_weights(self, module):
        super()._init_weights(module)

        if hasattr(self.config, "initializer_range"):
            std = self.config.initializer_range
        else:
            # 0.02 is the standard default value across the library
            std = getattr(self.config.get_text_config(), "initializer_range", 0.02)

        if isinstance(module, ParakeetEncoderAttention):
            # Initialize positional bias parameters
            module.bias_u.data.normal_(mean=0.0, std=std)
            module.bias_v.data.normal_(mean=0.0, std=std)

    def _get_subsampling_output_length(self, input_lengths: torch.Tensor):
        encoder_config = self.config.encoder_config if isinstance(self.config, ParakeetCTCConfig) else self.config

        kernel_size = encoder_config.subsampling_conv_kernel_size
        stride = encoder_config.subsampling_conv_stride
        num_layers = int(math.log2(encoder_config.subsampling_factor))

        all_paddings = (kernel_size - 1) // 2 * 2
        add_pad = all_paddings - kernel_size
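        # Per conv stage this reproduces the standard conv length formula with 'SAME'-style padding:
        # out = floor((len + 2 * ((kernel_size - 1) // 2) - kernel_size) / stride) + 1, applied num_layers times.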
        lengths = input_lengths

        for _ in range(num_layers):
            lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + 1.0
            lengths = torch.floor(lengths)

        return lengths.to(dtype=torch.int)

    def _get_output_attention_mask(self, attention_mask: torch.Tensor, target_length: Optional[int] = None):
        """
        Convert the input attention mask to its subsampled form. `target_length` sets the desired output length, useful
        when the attention mask length differs from `sum(-1).max()` (i.e., when the longest sequence in the batch is
        padded).
        """
        output_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
        # Use target_length if provided, otherwise use max length in batch
        max_length = target_length if target_length is not None else output_lengths.max()
        attention_mask = torch.arange(max_length, device=attention_mask.device) < output_lengths[:, None]
        return attention_mask


@auto_docstring(
    custom_intro="""
    The Parakeet Encoder model, based on the [Fast Conformer architecture](https://huggingface.co/papers/2305.05084).
    """
)
class ParakeetEncoder(ParakeetPreTrainedModel):
    config: ParakeetEncoderConfig
    base_model_prefix = "encoder"

    def __init__(self, config: ParakeetEncoderConfig):
        super().__init__(config)
        self.config = config
        self.gradient_checkpointing = False

        self.dropout = config.dropout
        self.dropout_positions = config.dropout_positions
        self.layerdrop = config.layerdrop

        self.input_scale = math.sqrt(config.hidden_size) if config.scale_input else 1.0
        self.subsampling = ParakeetEncoderSubsamplingConv2D(config)
        self.encode_positions = ParakeetEncoderRelPositionalEncoding(config)

        self.layers = nn.ModuleList(
            [ParakeetEncoderBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

        self.post_init()

    @auto_docstring
    @check_model_inputs
    @can_return_tuple
    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        r"""
        Example:

        ```python
        >>> from transformers import AutoProcessor, ParakeetEncoder
        >>> from datasets import load_dataset, Audio

        >>> model_id = "nvidia/parakeet-ctc-1.1b"
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> encoder = ParakeetEncoder.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))

        >>> inputs = processor(ds[0]["audio"]["array"])
        >>> encoder_outputs = encoder(**inputs)

        >>> print(encoder_outputs.last_hidden_state.shape)
        ```
        """

        hidden_states = self.subsampling(input_features, attention_mask)
        hidden_states = hidden_states * self.input_scale
        position_embeddings = self.encode_positions(hidden_states)

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        position_embeddings = nn.functional.dropout(
            position_embeddings, p=self.dropout_positions, training=self.training
        )

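        # Build a pairwise (batch, 1, seq, seq) boolean mask from the 1D padding mask: position (i, j) is
        # attendable only when both frame i and frame j are valid, so padded frames neither attend nor are
        # attended to.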
        if attention_mask is not None:
            attention_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
            attention_mask = attention_mask.unsqueeze(1).expand(-1, hidden_states.shape[1], -1)
            attention_mask = attention_mask & attention_mask.transpose(1, 2)
            attention_mask = attention_mask.unsqueeze(1)

        for encoder_layer in self.layers:
            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:  # skip the layer
                    to_drop = True

            if not to_drop:
                hidden_states = encoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )

        return BaseModelOutput(last_hidden_state=hidden_states)


@dataclass
class ParakeetGenerateOutput(ModelOutput):
    """
    Outputs of Parakeet models.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor
    logits: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None


@auto_docstring(
    custom_intro="""
    Parakeet Encoder with a Connectionist Temporal Classification (CTC) head.
    """
)
class ParakeetForCTC(ParakeetPreTrainedModel):
    config: ParakeetCTCConfig

    def __init__(self, config: ParakeetCTCConfig):
        super().__init__(config)
        self.encoder = ParakeetEncoder(config.encoder_config)
        # Conv rather than linear to be consistent with the NeMo decoding layer
        self.ctc_head = nn.Conv1d(config.encoder_config.hidden_size, config.vocab_size, kernel_size=1)

        self.post_init()

    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutput:
        r"""
        Example:

        ```python
        >>> from transformers import AutoProcessor, ParakeetForCTC
        >>> from datasets import load_dataset, Audio

        >>> model_id = "nvidia/parakeet-ctc-1.1b"
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> model = ParakeetForCTC.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))

        >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
        >>> outputs = model(**inputs)

        >>> print(outputs.loss)
        ```"""

        encoder_outputs = self.encoder(
            input_features=input_features,
            attention_mask=attention_mask,
            **kwargs,
        )

        hidden_states = encoder_outputs.last_hidden_state
        logits = self.ctc_head(hidden_states.transpose(1, 2)).transpose(1, 2)

        loss = None
        if labels is not None:
            # retrieve loss input_lengths from attention_mask
            attention_mask = (
                attention_mask if attention_mask is not None else torch.ones_like(input_features, dtype=torch.long)
            )
            input_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))

            # assuming that padded label positions are filled with `pad_token_id`
            # (which is also the CTC blank token) when not being attended to
            labels_mask = labels != self.config.pad_token_id
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        return CausalLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    @torch.no_grad()
    def generate(
        self,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict_in_generate: bool = False,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[ParakeetGenerateOutput, torch.LongTensor]:
        r"""
        Example:

        ```python
        >>> from transformers import AutoProcessor, ParakeetForCTC
        >>> from datasets import load_dataset, Audio

        >>> model_id = "nvidia/parakeet-ctc-1.1b"
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> model = ParakeetForCTC.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))

        >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
        >>> predicted_ids = model.generate(**inputs)
        >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

        >>> print(transcription)
        ```
        """
        kwargs["return_dict"] = True
        outputs: CausalLMOutput = self.forward(
            input_features=input_features,
            attention_mask=attention_mask,
            **kwargs,
        )

        # greedy decoding
        sequences = outputs.logits.argmax(dim=-1)
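        # The frame-level argmax still contains CTC blanks and repeated ids; collapsing repeats and removing
        # blanks to produce text is handled downstream by the tokenizer's `batch_decode`.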

        # mask out padded tokens
        if attention_mask is not None:
            attention_mask = self._get_output_attention_mask(attention_mask, target_length=sequences.shape[1])
            sequences[~attention_mask] = self.config.pad_token_id

        if return_dict_in_generate:
            return ParakeetGenerateOutput(
                sequences=sequences,
                logits=outputs.logits,
                attentions=outputs.attentions,
                hidden_states=outputs.hidden_states,
            )

        return sequences


__all__ = ["ParakeetForCTC", "ParakeetEncoder", "ParakeetPreTrainedModel"]
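Since the PR also adds a pipeline example, a minimal pipeline-level sketch of end-to-end transcription (the checkpoint id matches the doc examples above; "sample.flac" is a placeholder for any local 16 kHz audio file):

```python
from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="nvidia/parakeet-ctc-1.1b")
print(asr("sample.flac")["text"])
```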
src/transformers/models/parakeet/modular_parakeet.py (new file)
@@ -0,0 +1,628 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Parakeet model."""

import math
from dataclasses import dataclass
from typing import Callable, Optional, Union

import torch
from torch import nn

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, CausalLMOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.generic import check_model_inputs
from ..fastspeech2_conformer.modeling_fastspeech2_conformer import FastSpeech2ConformerConvolutionModule
from ..llama.modeling_llama import LlamaAttention, eager_attention_forward
from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig
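# Modular transformers note: classes below may subclass other models' modules (`FastSpeech2ConformerConvolutionModule`,
# `LlamaAttention`); the modular converter inlines those parent definitions when it generates modeling_parakeet.py,
# which is why the generated file above contains the fully expanded code.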


class ParakeetEncoderRelPositionalEncoding(nn.Module):
    """Relative positional encoding for Parakeet."""

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: ParakeetEncoderConfig, device=None):
        super().__init__()
        self.max_position_embeddings = config.max_position_embeddings
        base = 10000.0
        inv_freq = 1.0 / (
            base
            ** (
                torch.arange(0, config.hidden_size, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
                / config.hidden_size
            )
        )

        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, hidden_states: torch.Tensor):
        seq_length = hidden_states.shape[1]
        if seq_length > self.max_position_embeddings:
            raise ValueError(
                f"Sequence length {seq_length} has to be less than or equal to "
                f"config.max_position_embeddings {self.max_position_embeddings}."
            )

        position_ids = torch.arange(seq_length - 1, -seq_length, -1, device=hidden_states.device)
        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(hidden_states.shape[0], -1, 1).to(hidden_states.device)
        )
        position_ids_expanded = position_ids[None, None, :].float()

        device_type = (
            hidden_states.device.type
            if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
            else "cpu"
        )
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            sin = freqs.sin()
            cos = freqs.cos()
            # interleave sin and cos
            pos_embed = torch.stack([sin, cos], dim=-1)
            pos_embed = pos_embed.reshape(*pos_embed.shape[:-2], -1)

        return pos_embed.to(dtype=hidden_states.dtype)


class ParakeetEncoderFeedForward(nn.Module):
    def __init__(self, config: ParakeetEncoderConfig):
        super().__init__()
        self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.attention_bias)
        self.activation = ACT2FN[config.hidden_act]
        self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.attention_bias)
        self.activation_dropout = config.activation_dropout

    def forward(self, hidden_states):
        hidden_states = self.activation(self.linear1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.linear2(hidden_states)
        return hidden_states


class ParakeetEncoderConvolutionModule(FastSpeech2ConformerConvolutionModule):
    def __init__(self, config: ParakeetEncoderConfig, module_config=None):
        super().__init__(config, module_config)


class ParakeetEncoderAttention(LlamaAttention):
    """Multi-head attention with relative positional encoding. See section 3.3 of https://huggingface.co/papers/1901.02860."""

    def __init__(self, config: ParakeetEncoderConfig, layer_idx: int):
        super().__init__(config, layer_idx=layer_idx)
        self.is_causal = False
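        # Reuses Llama's q/k/v/o projections, scaling and dropout setup; only `forward` is overridden below to
        # add the Transformer-XL style relative-position terms via the extra parameters that follow.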
        # W_{k,R} projection
        self.relative_k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
        # global content bias
        self.bias_u = nn.Parameter(torch.zeros(config.num_attention_heads, self.head_dim))
        # global positional bias
        self.bias_v = nn.Parameter(torch.zeros(config.num_attention_heads, self.head_dim))

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_shape = hidden_states.shape[:-1]
        batch_size, seq_length = input_shape
        hidden_shape = (batch_size, seq_length, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        query_states_with_bias_u = query_states + self.bias_u.view(
            1, self.config.num_attention_heads, 1, self.head_dim
        )
        query_states_with_bias_v = query_states + self.bias_v.view(
            1, self.config.num_attention_heads, 1, self.head_dim
        )

        relative_key_states = self.relative_k_proj(position_embeddings)
        relative_key_states = relative_key_states.view(batch_size, -1, self.config.num_attention_heads, self.head_dim)

        # terms (b) and (d)
        matrix_bd = query_states_with_bias_v @ relative_key_states.permute(0, 2, 3, 1)
        matrix_bd = self._rel_shift(matrix_bd)
        matrix_bd = matrix_bd[..., :seq_length]
        matrix_bd = matrix_bd * self.scaling

        if attention_mask is not None:
            # here the original codebase uses -10000.0 rather than float("-inf") and then a manual masked fill with 0.0s
            # see: https://github.com/NVIDIA-NeMo/NeMo/blob/8cfedd7203462cb251a914e700e5605444277561/nemo/collections/asr/parts/submodules/multi_head_attention.py#L320-L340
            # we rather went for a straightforward approach with float("-inf")
            matrix_bd = matrix_bd.masked_fill_(attention_mask.logical_not(), float("-inf"))

        # will compute matrix_ac - terms (a) and (c) - and add matrix_bd
        attn_output, attn_weights = attention_interface(
            self,
            query=query_states_with_bias_u,
            key=key_states,
            value=value_states,
            attention_mask=matrix_bd,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

    def _rel_shift(self, attention_scores):
        """Relative position shift for Shaw et al. style attention. See appendix B of https://huggingface.co/papers/1901.02860."""
        batch_size, num_heads, query_length, position_length = attention_scores.shape
        attention_scores = nn.functional.pad(attention_scores, pad=(1, 0))
        attention_scores = attention_scores.view(batch_size, num_heads, -1, query_length)
        attention_scores = attention_scores[:, :, 1:].view(batch_size, num_heads, query_length, position_length)
        return attention_scores


class ParakeetEncoderSubsamplingConv2D(nn.Module):
    def __init__(self, config: ParakeetEncoderConfig):
        super().__init__()

        self.kernel_size = config.subsampling_conv_kernel_size
        self.stride = config.subsampling_conv_stride
        self.channels = config.subsampling_conv_channels
        self.padding = (self.kernel_size - 1) // 2
        self.num_layers = int(math.log2(config.subsampling_factor))

        # define layers
        self.layers = nn.ModuleList()
        self.layers.append(
            nn.Conv2d(1, self.channels, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding)
        )
        self.layers.append(nn.ReLU())
        for _ in range(self.num_layers - 1):
            # depthwise conv
            self.layers.append(
                nn.Conv2d(
                    self.channels,
                    self.channels,
                    kernel_size=self.kernel_size,
                    stride=self.stride,
                    padding=self.padding,
                    groups=self.channels,
                )
            )
            # pointwise conv
            self.layers.append(nn.Conv2d(self.channels, self.channels, kernel_size=1))
            # activation
            self.layers.append(nn.ReLU())

        out_length = config.num_mel_bins // (self.stride**self.num_layers)
        self.linear = nn.Linear(config.subsampling_conv_channels * out_length, config.hidden_size, bias=True)

    def _get_output_length(self, input_lengths: torch.Tensor, conv_layer: nn.Conv2d):
        if hasattr(conv_layer, "stride") and conv_layer.stride != (1, 1):
            padding = conv_layer.padding
            kernel_size = conv_layer.kernel_size[0]
            stride = conv_layer.stride[0]

            output_lengths = (input_lengths + padding[0] + padding[1] - kernel_size) // stride + 1
            return output_lengths

        return input_lengths

    def forward(self, input_features: torch.Tensor, attention_mask: torch.Tensor = None):
        hidden_states = input_features.unsqueeze(1)
        current_lengths = attention_mask.sum(-1) if attention_mask is not None else None

        for layer in self.layers:
            hidden_states = layer(hidden_states)

            # mask the hidden states
            if isinstance(layer, nn.Conv2d) and attention_mask is not None:
                current_lengths = self._get_output_length(current_lengths, layer)
                current_seq_length = hidden_states.shape[2]
                channel_mask = (
                    torch.arange(current_seq_length, device=attention_mask.device) < current_lengths[:, None]
                )
                hidden_states *= channel_mask[:, None, :, None]

        hidden_states = hidden_states.transpose(1, 2).reshape(hidden_states.shape[0], hidden_states.shape[2], -1)
        hidden_states = self.linear(hidden_states)

        return hidden_states


class ParakeetEncoderBlock(GradientCheckpointingLayer):
    def __init__(self, config: ParakeetEncoderConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.gradient_checkpointing = False

        self.feed_forward1 = ParakeetEncoderFeedForward(config)
        self.self_attn = ParakeetEncoderAttention(config, layer_idx)
        self.conv = ParakeetEncoderConvolutionModule(config)
        self.feed_forward2 = ParakeetEncoderFeedForward(config)

        self.norm_feed_forward1 = nn.LayerNorm(config.hidden_size)
        self.norm_self_att = nn.LayerNorm(config.hidden_size)
        self.norm_conv = nn.LayerNorm(config.hidden_size)
        self.norm_feed_forward2 = nn.LayerNorm(config.hidden_size)
        self.norm_out = nn.LayerNorm(config.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.feed_forward1(self.norm_feed_forward1(hidden_states))
        hidden_states = residual + 0.5 * hidden_states  # the conformer architecture uses a factor of 0.5

        normalized_hidden_states = self.norm_self_att(hidden_states)
        attn_output, _ = self.self_attn(
            hidden_states=normalized_hidden_states,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output

        conv_output = self.conv(self.norm_conv(hidden_states), attention_mask=attention_mask)
        hidden_states = hidden_states + conv_output

        ff2_output = self.feed_forward2(self.norm_feed_forward2(hidden_states))
        hidden_states = hidden_states + 0.5 * ff2_output  # the conformer architecture uses a factor of 0.5

        hidden_states = self.norm_out(hidden_states)

        return hidden_states


@auto_docstring
class ParakeetPreTrainedModel(PreTrainedModel):
    config: ParakeetCTCConfig
    base_model_prefix = "model"
    main_input_name = "input_features"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ParakeetEncoderBlock"]
    _supports_flat_attention_mask = True
    _supports_sdpa = True
    _supports_flex_attn = True

    # TODO: @eustlb, add support when flash attention supports custom attention bias
    _supports_flash_attn = False

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": ParakeetEncoderBlock,
        "attentions": ParakeetEncoderAttention,
    }

    def _init_weights(self, module):
        super()._init_weights(module)

        if hasattr(self.config, "initializer_range"):
            std = self.config.initializer_range
        else:
            # 0.02 is the standard default value across the library
            std = getattr(self.config.get_text_config(), "initializer_range", 0.02)

        if isinstance(module, ParakeetEncoderAttention):
            # Initialize positional bias parameters
            module.bias_u.data.normal_(mean=0.0, std=std)
            module.bias_v.data.normal_(mean=0.0, std=std)

    def _get_subsampling_output_length(self, input_lengths: torch.Tensor):
        encoder_config = self.config.encoder_config if isinstance(self.config, ParakeetCTCConfig) else self.config

        kernel_size = encoder_config.subsampling_conv_kernel_size
        stride = encoder_config.subsampling_conv_stride
        num_layers = int(math.log2(encoder_config.subsampling_factor))

        all_paddings = (kernel_size - 1) // 2 * 2
        add_pad = all_paddings - kernel_size
        lengths = input_lengths

        for _ in range(num_layers):
            lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + 1.0
            lengths = torch.floor(lengths)

        return lengths.to(dtype=torch.int)

    def _get_output_attention_mask(self, attention_mask: torch.Tensor, target_length: Optional[int] = None):
        """
        Convert the input attention mask to its subsampled form. `target_length` sets the desired output length, useful
        when the attention mask length differs from `sum(-1).max()` (i.e., when the longest sequence in the batch is
        padded).
        """
        output_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
        # Use target_length if provided, otherwise use max length in batch
        max_length = target_length if target_length is not None else output_lengths.max()
        attention_mask = torch.arange(max_length, device=attention_mask.device) < output_lengths[:, None]
        return attention_mask


@auto_docstring(
    custom_intro="""
    The Parakeet Encoder model, based on the [Fast Conformer architecture](https://huggingface.co/papers/2305.05084).
    """
)
class ParakeetEncoder(ParakeetPreTrainedModel):
    config: ParakeetEncoderConfig
    base_model_prefix = "encoder"

    def __init__(self, config: ParakeetEncoderConfig):
        super().__init__(config)
        self.config = config
        self.gradient_checkpointing = False

        self.dropout = config.dropout
        self.dropout_positions = config.dropout_positions
        self.layerdrop = config.layerdrop

        self.input_scale = math.sqrt(config.hidden_size) if config.scale_input else 1.0
        self.subsampling = ParakeetEncoderSubsamplingConv2D(config)
        self.encode_positions = ParakeetEncoderRelPositionalEncoding(config)

        self.layers = nn.ModuleList(
            [ParakeetEncoderBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

        self.post_init()

    @auto_docstring
    @check_model_inputs
    @can_return_tuple
    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        r"""
        Example:

        ```python
        >>> from transformers import AutoProcessor, ParakeetEncoder
        >>> from datasets import load_dataset, Audio

        >>> model_id = "nvidia/parakeet-ctc-1.1b"
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> encoder = ParakeetEncoder.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))

        >>> inputs = processor(ds[0]["audio"]["array"])
        >>> encoder_outputs = encoder(**inputs)

        >>> print(encoder_outputs.last_hidden_state.shape)
        ```
        """

        hidden_states = self.subsampling(input_features, attention_mask)
        hidden_states = hidden_states * self.input_scale
        position_embeddings = self.encode_positions(hidden_states)

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        position_embeddings = nn.functional.dropout(
            position_embeddings, p=self.dropout_positions, training=self.training
        )

        if attention_mask is not None:
            attention_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
            attention_mask = attention_mask.unsqueeze(1).expand(-1, hidden_states.shape[1], -1)
            attention_mask = attention_mask & attention_mask.transpose(1, 2)
            attention_mask = attention_mask.unsqueeze(1)

        for encoder_layer in self.layers:
            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:  # skip the layer
                    to_drop = True

            if not to_drop:
                hidden_states = encoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )

        return BaseModelOutput(last_hidden_state=hidden_states)


@dataclass
|
||||
class ParakeetGenerateOutput(ModelOutput):
|
||||
"""
|
||||
Outputs of Parakeet models.
|
||||
|
||||
Args:
|
||||
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
|
||||
if all batches finished early due to the `eos_token_id`.
|
||||
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
|
||||
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
|
||||
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
|
||||
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
|
||||
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
|
||||
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
|
||||
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
|
||||
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
|
||||
"""
|
||||
|
||||
sequences: torch.LongTensor
|
||||
logits: Optional[tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
|
||||
hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None


@auto_docstring(
    custom_intro="""
    Parakeet Encoder with a Connectionist Temporal Classification (CTC) head.
    """
)
class ParakeetForCTC(ParakeetPreTrainedModel):
    config: ParakeetCTCConfig

    def __init__(self, config: ParakeetCTCConfig):
        super().__init__(config)
        self.encoder = ParakeetEncoder(config.encoder_config)
        # Conv rather than linear to be consistent with the NeMo decoding layer
        self.ctc_head = nn.Conv1d(config.encoder_config.hidden_size, config.vocab_size, kernel_size=1)

        self.post_init()

    @auto_docstring
    @can_return_tuple
    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutput:
        r"""
        Example:

        ```python
        >>> from transformers import AutoProcessor, ParakeetForCTC
        >>> from datasets import load_dataset, Audio

        >>> model_id = "nvidia/parakeet-ctc-1.1b"
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> model = ParakeetForCTC.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))

        >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
        >>> outputs = model(**inputs)

        >>> print(outputs.loss)
        ```"""

        encoder_outputs = self.encoder(
            input_features=input_features,
            attention_mask=attention_mask,
            **kwargs,
        )

        hidden_states = encoder_outputs.last_hidden_state
        logits = self.ctc_head(hidden_states.transpose(1, 2)).transpose(1, 2)

        loss = None
        if labels is not None:
            # retrieve loss input_lengths from the frame-level (batch, time) attention_mask
            attention_mask = (
                attention_mask
                if attention_mask is not None
                else torch.ones(input_features.shape[:2], dtype=torch.long, device=input_features.device)
            )
            input_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))

            # padded label positions are filled with the pad token (also the CTC blank)
            # and are excluded from the flattened targets
            labels_mask = labels != self.config.pad_token_id
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        return CausalLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
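
To make the shape conventions concrete, here is a standalone sketch of the same `nn.functional.ctc_loss` call with toy sizes (blank id 0 assumed, mirroring the pad/blank convention above):

```python
import torch
import torch.nn as nn

batch, frames, vocab = 2, 6, 5  # toy sizes; `frames` plays the role of the encoder output length
logits = torch.randn(batch, frames, vocab)

# ctc_loss expects (frames, batch, vocab) log-probabilities, computed in fp32
log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

targets = torch.tensor([1, 2, 3, 2, 4])  # labels of both samples flattened back to back
input_lengths = torch.tensor([6, 5])     # valid encoder frames per sample
target_lengths = torch.tensor([3, 2])    # targets[:3] belong to sample 0, targets[3:] to sample 1

loss = nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0)
print(loss)
```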

    @torch.no_grad()
    def generate(
        self,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict_in_generate: bool = False,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[ParakeetGenerateOutput, torch.LongTensor]:
        r"""
        Example:

        ```python
        >>> from transformers import AutoProcessor, ParakeetForCTC
        >>> from datasets import load_dataset, Audio

        >>> model_id = "nvidia/parakeet-ctc-1.1b"
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> model = ParakeetForCTC.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))

        >>> inputs = processor(ds[0]["audio"]["array"])
        >>> predicted_ids = model.generate(**inputs)
        >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

        >>> print(transcription)
        ```
        """
        kwargs["return_dict"] = True
        outputs: CausalLMOutput = self.forward(
            input_features=input_features,
            attention_mask=attention_mask,
            **kwargs,
        )

        # greedy decoding: one token per encoder output frame
        sequences = outputs.logits.argmax(dim=-1)

        # mask out padded frames with the pad (blank) token
        if attention_mask is not None:
            attention_mask = self._get_output_attention_mask(attention_mask, target_length=sequences.shape[1])
            sequences[~attention_mask] = self.config.pad_token_id

        if return_dict_in_generate:
            return ParakeetGenerateOutput(
                sequences=sequences,
                logits=outputs.logits,
                attentions=outputs.attentions,
                hidden_states=outputs.hidden_states,
            )

        return sequences


__all__ = ["ParakeetForCTC", "ParakeetEncoder", "ParakeetPreTrainedModel"]

src/transformers/models/parakeet/processing_parakeet.py (new file)
@@ -0,0 +1,87 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union

from ...audio_utils import AudioInput, make_list_of_audio
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging


logger = logging.get_logger(__name__)


class ParakeetProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        "audio_kwargs": {
            "sampling_rate": 16000,
            "padding": "longest",
        },
        "text_kwargs": {
            "padding": True,
            "padding_side": "right",
            "add_special_tokens": False,
        },
        "common_kwargs": {"return_tensors": "pt"},
    }


class ParakeetProcessor(ProcessorMixin):
    attributes = ["feature_extractor", "tokenizer"]
    feature_extractor_class = "ParakeetFeatureExtractor"
    tokenizer_class = "ParakeetTokenizerFast"

    def __call__(
        self,
        audio: AudioInput,
        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput], None] = None,
        sampling_rate: Optional[int] = None,
        **kwargs: Unpack[ParakeetProcessorKwargs],
    ):
        audio = make_list_of_audio(audio)

        output_kwargs = self._merge_kwargs(
            ParakeetProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if sampling_rate is None:
            logger.warning_once(
                f"You've provided audio without specifying the sampling rate. It will be assumed to be {output_kwargs['audio_kwargs']['sampling_rate']}, which can result in silent errors."
            )
        elif sampling_rate != output_kwargs["audio_kwargs"]["sampling_rate"]:
            raise ValueError(
                f"The sampling rate of the audio ({sampling_rate}) does not match the sampling rate of the processor ({output_kwargs['audio_kwargs']['sampling_rate']}). Please resample the audio to the expected sampling rate."
            )

        if audio is not None:
            inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
        if text is not None:
            encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])

        if text is None:
            return inputs
        else:
            inputs["labels"] = encodings["input_ids"]
            return inputs

    @property
    def model_input_names(self):
        feature_extractor_input_names = self.feature_extractor.model_input_names
        return feature_extractor_input_names + ["labels"]


__all__ = ["ParakeetProcessor"]

src/transformers/models/parakeet/tokenization_parakeet.py (new file)
@@ -0,0 +1,54 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from typing import Optional, Union

from ...tokenization_utils_fast import PreTrainedTokenizerFast


class ParakeetTokenizerFast(PreTrainedTokenizerFast):
    """
    Inherits all methods from [`PreTrainedTokenizerFast`]. Users should refer to this superclass for more information
    regarding those methods, except for `_decode`, which is overridden to adapt it to CTC decoding:
    1. Group consecutive tokens
    2. Filter out the blank token
    """

    def _decode(
        self,
        token_ids: Union[int, list[int]],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        group_tokens: bool = True,
        **kwargs,
    ) -> str:
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        if group_tokens:
            # itertools.groupby yields (key, group) pairs; keeping the key collapses runs of repeated tokens
            token_ids = [token_group[0] for token_group in itertools.groupby(token_ids)]

        # for CTC we filter out the blank token, which is the pad token
        token_ids = [token for token in token_ids if token != self.pad_token_id]

        return super()._decode(
            token_ids=token_ids,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )


__all__ = ["ParakeetTokenizerFast"]

tests/fixtures/parakeet/expected_results_batch.json (new vendored file)
File diff suppressed because one or more lines are too long

tests/fixtures/parakeet/expected_results_single.json (new vendored file)
@@ -0,0 +1 @@
{"transcriptions": ["mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"], "scores": [-0.08922013640403748], "token_ids": [[1024, 1024, 1024, 1024, 1024, 1024, 19, 37, 132, 1024, 1024, 264, 128, 1024, 1024, 1024, 132, 1024, 58, 1024, 5, 645, 1024, 1000, 82, 52, 1024, 34, 1024, 5, 19, 68, 1007, 52, 1024, 235, 1024, 388, 1024, 27, 1024, 25, 1024, 56, 1024, 103, 1024, 1024, 727, 112, 1024, 22, 1024, 56, 1006, 1009, 405, 1024, 1024, 217, 1024, 1024, 95, 1003, 1024, 133, 1006, 1024, 1024, 1024, 1024, 1024, 1024, 1024]]}

tests/models/parakeet/__init__.py (new empty file)

tests/models/parakeet/test_feature_extraction_parakeet.py (new file)
@@ -0,0 +1,197 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the Parakeet feature extraction."""

import itertools
import random
import unittest

import numpy as np

from transformers import ParakeetFeatureExtractor
from transformers.testing_utils import require_torch
from transformers.utils import is_datasets_available, is_torch_available

from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin


if is_torch_available():
    import torch

if is_datasets_available():
    from datasets import load_dataset

global_rng = random.Random()


def floats_list(shape, scale=1.0, rng=None, name=None):
    """Creates a random float32 tensor"""
    if rng is None:
        rng = global_rng

    values = []
    for batch_idx in range(shape[0]):
        values.append([])
        for _ in range(shape[1]):
            values[-1].append(rng.random() * scale)

    return values


class ParakeetFeatureExtractionTester:
    def __init__(
        self,
        parent,
        batch_size=7,
        min_seq_length=400,
        max_seq_length=2000,
        feature_size=80,
        hop_length=160,
        win_length=400,
        n_fft=512,
        sampling_rate=16000,
        padding_value=0.0,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.min_seq_length = min_seq_length
        self.max_seq_length = max_seq_length
        self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1)
        self.feature_size = feature_size
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_fft = n_fft
        self.sampling_rate = sampling_rate
        self.padding_value = padding_value

    def prepare_feat_extract_dict(self):
        return {
            "feature_size": self.feature_size,
            "hop_length": self.hop_length,
            "win_length": self.win_length,
            "n_fft": self.n_fft,
            "sampling_rate": self.sampling_rate,
            "padding_value": self.padding_value,
        }

    # Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTester.prepare_inputs_for_common
    def prepare_inputs_for_common(self, equal_length=False, numpify=False):
        def _flatten(list_of_lists):
            return list(itertools.chain(*list_of_lists))

        if equal_length:
            speech_inputs = [floats_list((self.max_seq_length, self.feature_size)) for _ in range(self.batch_size)]
        else:
            # make sure that inputs increase in size
            speech_inputs = [
                floats_list((x, self.feature_size))
                for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
            ]
        if numpify:
            speech_inputs = [np.asarray(x) for x in speech_inputs]
        return speech_inputs


class ParakeetFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
    feature_extraction_class = ParakeetFeatureExtractor

    def setUp(self):
        self.feat_extract_tester = ParakeetFeatureExtractionTester(self)

    def _load_datasamples(self, num_samples):
        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        # automatic decoding with librispeech
        speech_samples = ds.sort("id")[:num_samples]["audio"]

        return [x["array"] for x in speech_samples]

    @require_torch
    def test_torch_integration(self):
        """
        reproducer: https://gist.github.com/eustlb/c4a0999e54466b7e8d8b040d8e0900df
        """
        # fmt: off
        EXPECTED_INPUT_FEATURES = torch.tensor(
            [
                0.60935932, 1.18187428, 1.29877627, 1.36461377, 1.09311509, 1.39821815,
                1.63753450, 1.37100816, 1.26510608, 1.70332706, 1.69067430, 1.28770995,
                1.52999651, 1.77962756, 1.71420062, 1.21944094, 1.30884087, 1.44343364,
                1.17694926, 1.42690814, 1.78877723, 1.68655288, 1.27155364, 1.66103351,
                1.75820673, 1.41575801, 1.40622294, 1.70603478, 1.63117850, 1.13353217,
            ]
        )
        # fmt: on

        input_speech = self._load_datasamples(1)
        feature_extractor = ParakeetFeatureExtractor()
        inputs = feature_extractor(input_speech, return_tensors="pt")

        self.assertEqual(inputs.input_features.shape, (1, 586, 80))
        torch.testing.assert_close(inputs.input_features[0, 100, :30], EXPECTED_INPUT_FEATURES, atol=1e-4, rtol=1e-4)

        self.assertEqual(inputs.attention_mask.shape, (1, 586))
        # last frame should be masked
        self.assertEqual(inputs.attention_mask.sum(), 585)

    @require_torch
    def test_torch_integration_batch(self):
        """
        reproducer: https://gist.github.com/eustlb/c4a0999e54466b7e8d8b040d8e0900df
        """
        # fmt: off
        EXPECTED_INPUT_FEATURES = torch.tensor(
            [
                [ 0.60935932,  1.18187428,  1.29877627,  1.36461377,  1.09311533,
                  1.39821827,  1.63753450,  1.37100816,  1.26510608,  1.70332706,
                  1.69067478,  1.28770995,  1.52999651,  1.77962780,  1.71420062,
                  1.21944094,  1.30884087,  1.44343400,  1.17694926,  1.42690814,
                  1.78877664,  1.68655288,  1.27155364,  1.66103351,  1.75820673,
                  1.41575801,  1.40622294,  1.70603478,  1.63117862,  1.13353217],
                [ 0.58339858,  0.54317272,  0.46222782,  0.34154415,  0.17806509,
                  0.32182255,  0.28909618,  0.02141305, -0.09710173, -0.35818669,
                 -0.48172510, -0.52942866, -0.58029658, -0.70519227, -0.67929971,
                 -0.54698551, -0.28611183, -0.24780270, -0.31363955, -0.41913241,
                 -0.32394424, -0.44897896, -0.68657434, -0.62047797, -0.46886450,
                 -0.65987164, -1.02435589, -0.58527517, -0.56095684, -0.73582536],
                [-0.91937613, -0.97933632, -1.06843162, -1.02642107, -0.94232899,
                 -0.83840621, -0.82306921, -0.45763230, -0.45182887, -0.75917768,
                 -0.42541453, -0.28512970, -0.39637473, -0.66478080, -0.68004298,
                 -0.49690303, -0.31799242, -0.12917191,  0.13149273,  0.10163058,
                 -0.40041649,  0.05001565,  0.23906317,  0.28816083,  0.14308788,
                 -0.29588422, -0.05428466,  0.14418560,  0.28865972, -0.12138986],
                [ 0.73217624,  0.84484011,  0.79323846,  0.66315967,  0.41556871,
                  0.88633078,  0.90718138,  0.91268104,  1.15920067,  1.26141894,
                  1.10222173,  0.92990804,  0.96352047,  0.88142169,  0.56635213,
                  0.71491158,  0.81301254,  0.67301887,  0.74780160,  0.64429688,
                  0.22885245,  0.47035533,  0.46498337,  0.17544533,  0.44458991,
                  0.79245001,  0.57207537,  0.85768145,  1.00491571,  0.93360955],
                [ 1.40496337,  1.32492661,  1.16519547,  0.98379827,  0.77614164,
                  0.95871657,  0.81910741,  1.23010278,  1.33011520,  1.16538525,
                  1.28319681,  1.45041633,  1.33421600,  0.91677380,  0.67107433,
                  0.52890682,  0.82009870,  1.15821445,  1.15343642,  1.10958862,
                  1.44962490,  1.44485891,  1.46043479,  1.90800595,  1.95863307,
                  1.63670933,  1.49021459,  1.18701911,  0.74906683,  0.84700620]
            ]
        )
        # fmt: on

        input_speech = self._load_datasamples(5)
        feature_extractor = ParakeetFeatureExtractor()
        inputs = feature_extractor(input_speech, return_tensors="pt")

        self.assertEqual(inputs.input_features.shape, (5, 2941, 80))
        torch.testing.assert_close(inputs.input_features[:, 100, :30], EXPECTED_INPUT_FEATURES, atol=1e-4, rtol=1e-4)

        self.assertEqual(inputs.attention_mask.shape, (5, 2941))
        self.assertListEqual(inputs.attention_mask.sum(dim=-1).tolist(), [585, 481, 1248, 990, 2940])

tests/models/parakeet/test_modeling_parakeet.py (new file)
@@ -0,0 +1,380 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Parakeet model."""

import json
import tempfile
import unittest
from pathlib import Path

from transformers import is_datasets_available, is_torch_available
from transformers.testing_utils import cleanup, require_torch, slow, torch_device

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask


if is_datasets_available():
    from datasets import Audio, load_dataset

if is_torch_available():
    import torch

    from transformers import (
        AutoProcessor,
        ParakeetCTCConfig,
        ParakeetEncoder,
        ParakeetEncoderConfig,
        ParakeetForCTC,
    )


class ParakeetEncoderModelTester:
    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=1024,
        is_training=True,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=256,
        hidden_act="silu",
        dropout=0,  # so gradient checkpointing doesn't fail
        conv_kernel_size=9,
        subsampling_factor=8,
        subsampling_conv_channels=32,
        use_bias=True,
        num_mel_bins=80,
        scale_input=True,
    ):
        # testing suite parameters
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.num_mel_bins = num_mel_bins
        self.is_training = is_training

        # config parameters
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.dropout = dropout
        self.conv_kernel_size = conv_kernel_size
        self.subsampling_factor = subsampling_factor
        self.subsampling_conv_channels = subsampling_conv_channels
        self.use_bias = use_bias
        self.scale_input = scale_input

        # Calculate output sequence length after subsampling
        self.output_seq_length = seq_length // subsampling_factor
        self.encoder_seq_length = self.output_seq_length
        self.key_length = self.output_seq_length

    def prepare_config_and_inputs(self):
        input_features = floats_tensor([self.batch_size, self.seq_length, self.num_mel_bins])
        attention_mask = random_attention_mask([self.batch_size, self.seq_length])
        config = self.get_config()

        return config, input_features, attention_mask

    def get_config(self):
        return ParakeetEncoderConfig(
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            dropout=self.dropout,
            dropout_positions=self.dropout,
            layerdrop=self.dropout,
            activation_dropout=self.dropout,
            attention_dropout=self.dropout,
            conv_kernel_size=self.conv_kernel_size,
            subsampling_factor=self.subsampling_factor,
            subsampling_conv_channels=self.subsampling_conv_channels,
            use_bias=self.use_bias,
            num_mel_bins=self.num_mel_bins,
            scale_input=self.scale_input,
        )
    def create_and_check_model(self, config, input_features, attention_mask):
        model = ParakeetEncoder(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(input_features, attention_mask=attention_mask)

        self.parent.assertEqual(
            result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, config.hidden_size)
        )

    def prepare_config_and_inputs_for_common(self):
        config, input_features, attention_mask = self.prepare_config_and_inputs()
        inputs_dict = {
            "input_features": input_features,
            "attention_mask": attention_mask,
        }
        return config, inputs_dict

    def check_ctc_loss(self, config, input_values, *args):
        model = ParakeetForCTC(config=config)
        model.to(torch_device)

        # make sure that dropout is disabled
        model.eval()

        input_values = input_values[:3]
        # frame-level (batch, time) mask; input_values is (batch, time, num_mel_bins)
        attention_mask = torch.ones(input_values.shape[:2], device=torch_device, dtype=torch.long)

        input_lengths = [input_values.shape[1] // i for i in [4, 2, 1]]
        max_length_labels = model._get_subsampling_output_length(torch.tensor(input_lengths))
        labels = ids_tensor((input_values.shape[0], min(max_length_labels) - 1), model.config.vocab_size)

        # pad input
        for i in range(len(input_lengths)):
            input_values[i, input_lengths[i] :] = 0.0
            attention_mask[i, input_lengths[i] :] = 0

        model.config.ctc_loss_reduction = "sum"
        sum_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()

        model.config.ctc_loss_reduction = "mean"
        mean_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()

        self.parent.assertTrue(isinstance(sum_loss, float))
        self.parent.assertTrue(isinstance(mean_loss, float))


@require_torch
class ParakeetEncoderModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (ParakeetEncoder,) if is_torch_available() else ()

    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False
    test_torch_exportable = True

    def setUp(self):
        self.model_tester = ParakeetEncoderModelTester(self)
        self.config_tester = ConfigTester(self, config_class=ParakeetEncoderConfig, has_text_modality=False)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="ParakeetEncoder does not use inputs_embeds")
    def test_model_get_set_embeddings(self):
        pass


class ParakeetForCTCModelTester:
    def __init__(self, parent, encoder_kwargs=None, is_training=True, vocab_size=128, pad_token_id=0):
        if encoder_kwargs is None:
            encoder_kwargs = {}

        self.parent = parent
        self.encoder_model_tester = ParakeetEncoderModelTester(parent, **encoder_kwargs)
        self.is_training = is_training

        self.batch_size = self.encoder_model_tester.batch_size
        self.output_seq_length = self.encoder_model_tester.output_seq_length
        self.num_hidden_layers = self.encoder_model_tester.num_hidden_layers
        # labels length used by the common tests; must not exceed output_seq_length
        # for the CTC loss to be valid (here 1024 // 8 == 128 == vocab_size)
        self.seq_length = vocab_size
        self.hidden_size = self.encoder_model_tester.hidden_size

        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id

    def prepare_config_and_inputs(self):
        _, input_features, attention_mask = self.encoder_model_tester.prepare_config_and_inputs()
        config = self.get_config()
        return config, input_features, attention_mask

    def get_config(self):
        return ParakeetCTCConfig.from_encoder_config(
            encoder_config=self.encoder_model_tester.get_config(),
            vocab_size=self.vocab_size,
            pad_token_id=self.pad_token_id,
        )

    def create_and_check_model(self, config, input_features, attention_mask):
        model = ParakeetForCTC(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(input_features, attention_mask=attention_mask)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.output_seq_length, self.vocab_size))

    def prepare_config_and_inputs_for_common(self):
        config, input_features, attention_mask = self.prepare_config_and_inputs()
        inputs_dict = {
            "input_features": input_features,
            "attention_mask": attention_mask,
        }
        return config, inputs_dict


@require_torch
class ParakeetForCTCModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (ParakeetForCTC,) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": ParakeetEncoder,
            "automatic-speech-recognition": ParakeetForCTC,
        }
        if is_torch_available()
        else {}
    )

    test_attention_outputs = False
    test_pruning = False
    test_resize_embeddings = False
    test_head_masking = False
    test_torch_exportable = True

    _is_composite = True

    def setUp(self):
        self.model_tester = ParakeetForCTCModelTester(self)
        self.config_tester = ConfigTester(self, config_class=ParakeetCTCConfig)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_ctc_loss_inference(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.encoder_model_tester.check_ctc_loss(*config_and_inputs)

    @unittest.skip(reason="ParakeetEncoder does not use inputs_embeds")
    def test_model_get_set_embeddings(self):
        pass
    # Original function assumes vision+text model, so overwrite since Parakeet is audio+text
    # Below is modified from `tests/models/granite_speech/test_modeling_granite_speech.py`
    def test_sdpa_can_dispatch_composite_models(self):
        if not self.has_attentions:
            self.skipTest(reason="Model architecture does not support attentions")

        if not self._is_composite:
            self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")

        for model_class in self.all_model_classes:
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model_sdpa = model_class.from_pretrained(tmpdirname)
                model_sdpa = model_sdpa.eval().to(torch_device)

                model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
                model_eager = model_eager.eval().to(torch_device)
                self.assertTrue(model_eager.config._attn_implementation == "eager")

                for name, submodule in model_eager.named_modules():
                    class_name = submodule.__class__.__name__
                    if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
                        raise ValueError("The eager model should not have SDPA attention layers")


@require_torch
class ParakeetForCTCIntegrationTest(unittest.TestCase):
    _dataset = None

    @classmethod
    def setUpClass(cls):
        cls.checkpoint_name = "nvidia/parakeet-ctc-1.1b"
        cls.dtype = torch.bfloat16
        cls.processor = AutoProcessor.from_pretrained(cls.checkpoint_name)

    def tearDown(self):
        cleanup(torch_device, gc_collect=True)

    @classmethod
    def _load_dataset(cls):
        # Lazy loading of the dataset. Because it is a class method, it will only be loaded once per pytest process.
        if cls._dataset is None:
            cls._dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
            cls._dataset = cls._dataset.cast_column(
                "audio", Audio(sampling_rate=cls.processor.feature_extractor.sampling_rate)
            )

    def _load_datasamples(self, num_samples):
        self._load_dataset()
        ds = self._dataset
        speech_samples = ds.sort("id")[:num_samples]["audio"]
        return [x["array"] for x in speech_samples]

    @slow
    def test_1b_model_integration(self):
        """
        bezzam reproducer (creates JSON directly in repo): https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_single-py
        eustlb reproducer: https://gist.github.com/eustlb/6e9e3aa85de3f7c340ec3c36e65f2fe6
        """
        RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_single.json"
        with open(RESULTS_PATH, "r") as f:
            raw_data = json.load(f)
        EXPECTED_TOKEN_IDS = torch.tensor(raw_data["token_ids"])
        EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"]

        samples = self._load_datasamples(1)
        model = ParakeetForCTC.from_pretrained(self.checkpoint_name, torch_dtype=self.dtype, device_map=torch_device)
        model.eval()
        model.to(torch_device)

        # -- apply
        inputs = self.processor(samples)
        inputs.to(torch_device, dtype=self.dtype)
        predicted_ids = model.generate(**inputs)
        torch.testing.assert_close(predicted_ids.cpu(), EXPECTED_TOKEN_IDS)
        predicted_transcripts = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
        self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS)

    @slow
    def test_1b_model_integration_batched(self):
        """
        bezzam reproducer (creates JSON directly in repo): https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_batched-py
        eustlb reproducer: https://gist.github.com/eustlb/575b5da58de34a70116a1955b1183596
        """

        RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_batch.json"
        with open(RESULTS_PATH, "r") as f:
            raw_data = json.load(f)
        EXPECTED_TOKEN_IDS = torch.tensor(raw_data["token_ids"])
        EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"]

        samples = self._load_datasamples(5)
        model = ParakeetForCTC.from_pretrained(self.checkpoint_name, torch_dtype=self.dtype, device_map=torch_device)
        model.eval()
        model.to(torch_device)

        # -- apply
        inputs = self.processor(samples)
        inputs.to(torch_device, dtype=self.dtype)
        predicted_ids = model.generate(**inputs)
        torch.testing.assert_close(predicted_ids.cpu(), EXPECTED_TOKEN_IDS)
        predicted_transcripts = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
        self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS)

tests/models/parakeet/test_processing_parakeet.py (new file)
@@ -0,0 +1,49 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import shutil
import tempfile
import unittest

from transformers import AutoProcessor, ParakeetProcessor
from transformers.testing_utils import require_torch, require_torchaudio

from ...test_processing_common import ProcessorTesterMixin


@require_torch
@require_torchaudio
class ParakeetProcessorTest(ProcessorTesterMixin, unittest.TestCase):
    processor_class = ParakeetProcessor
    text_input_name = "labels"

    @classmethod
    def setUpClass(cls):
        cls.tmpdirname = tempfile.mkdtemp()
        cls.checkpoint = "nvidia/parakeet-ctc-1.1b"
        processor = ParakeetProcessor.from_pretrained(cls.checkpoint)
        processor.save_pretrained(cls.tmpdirname)

    def get_tokenizer(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer

    def get_feature_extractor(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).feature_extractor

    def get_processor(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs)

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.tmpdirname, ignore_errors=True)

tests/models/parakeet/test_tokenization_parakeet.py (new file)
@@ -0,0 +1,53 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the ParakeetCTC tokenizer."""

import unittest

from transformers.models.parakeet import ParakeetTokenizerFast

from ...test_tokenization_common import TokenizerTesterMixin


class ParakeetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    slow_tokenizer_class = None
    rust_tokenizer_class = ParakeetTokenizerFast
    tokenizer_class = ParakeetTokenizerFast
    test_slow_tokenizer = False
    test_rust_tokenizer = True
    from_pretrained_id = "nvidia/parakeet-ctc-1.1b"

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        tokenizer = ParakeetTokenizerFast.from_pretrained(cls.from_pretrained_id)
        tokenizer.save_pretrained(cls.tmpdirname)

    @unittest.skip(
        reason="This test does not apply to ParakeetTokenizerFast. More details in the test docstring itself."
    )
    def test_added_tokens_do_lower_case(self):
        """
        The precompiled normalization from sentencepiece is `nmt_nfkc_cf`, which includes lowercasing, yet
        ParakeetTokenizerFast does not have a `do_lower_case` attribute. This results in the test failing.
        """
        pass

    @unittest.skip(reason="This needs a slow tokenizer. Parakeet does not have one!")
    def test_encode_decode_with_spaces(self):
        return

    @unittest.skip(reason="ParakeetTokenizerFast doesn't have tokenizer_file in its signature.")
    def test_rust_tokenizer_signature(self):
        pass

tests/test_modeling_common.py
@@ -55,6 +55,7 @@ from transformers.models.auto.modeling_auto import (
    MODEL_FOR_BACKBONE_MAPPING_NAMES,
    MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES,
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_CTC_MAPPING_NAMES,
    MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
    MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
    MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
@@ -657,6 +658,7 @@ class ModelTesterMixin:
                *get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES),
                *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES),
                *get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES),
                *get_values(MODEL_FOR_CTC_MAPPING_NAMES),
            ]:
                inputs_dict["labels"] = torch.zeros(
                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device

utils/check_docstrings.py
@@ -315,6 +315,7 @@ OBJECTS_TO_IGNORE = {
    "OpenAIGPTTokenizerFast",
    "OpenLlamaConfig",
    "PLBartConfig",
    "ParakeetCTCConfig",
    "PegasusConfig",
    "PegasusTokenizer",
    "PegasusTokenizerFast",