Add Parakeet (#39062)

* first commit

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* update to handle masking for bs>1

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* Add tests and docs

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* update model ids

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* update docs and improve style

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* update librosa location

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* import guard torch too

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* ruff code checks fix

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* ruff format check

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* updated to parakeet names

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* update script

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* Add tokenizer decoding

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* Remove other model dependency

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* clean tests

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* fix tests

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* linting

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* fix ruff lint warnings

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* move to separate folders

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* add parakeet ctc model code

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* simplify encoder structure

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* update documentation

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* add parakeet to toctree

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* fix tests

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* add parakeet doc

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* Address comments

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* Update featurizer to compute lens directly

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* fix ruff tests

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* fix encoding format

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* fix minor ctc decoding

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>

* revert modular_model_converter.py changes

* revert check_config_attributes.py changes

* refactor: fastconformer & parakeet_ctc -> parakeet

* modeling update

* test update

* propagate feature extractor updates

* propagate doc changes

* propagate doc changes

* propagate tokenization changes

* propagate conversion changes

* remove fastconformer tests

* remove modular

* update processor

* update processor

* test update

* diverse fixes

* 100% matching greedy batched

* Update conversion script.

* Refactor docs.

* Refactor auto loading.

* Refactor and fix tokenization and processing.

* Update integration test.

* Modeling fixes:
- ensure correct attention mask shape
- ensure layer drop returns valid output
- correct blank token ID when computing CTC loss

* Format and repo consistency.

* Update model doc.

* Fix feature extraction tests.

* Fix (most) tokenizer tests.

* Add pipeline example.

* Fixes

* Use eager_attention_forward from Llama.

* Small tweaks.

* Replace Sequential with ModuleList

* Add check if not all layers copied

* Clean tokenizer.

* Standardize FastSpeech2ConformerConvolutionModule for Parakeet.

* Switch to modular for modeling and processing.

* Add processor tests.

* Fix modeling tests.

* Formatting and docstrings.

* Add `return_attention_mask` like other feature extractors.

* clean up after merging main.

* nits on modeling

* configuration update

* nit

* simplification: use PretrainedTokenizerFast, simplify processor

* add dtype arg to mel_filter_bank

* feature extraction: simplify!

* modeling update

* change to ParakeetTokenizerFast

* correct attention mask handling

* auto update

* proc update

* test update

* feature extraction fixes

* modeling update

* conversion script update

* update feature integration tests

* update tokenization and tests

* processor tests

* revert audio_utils

* config docstring update

* blank_token -> pad_token

* modeling update

* doc update

* fix tests

* fix test

* fix tests

* address review comments

* add comment

* add comment

* explicitly not support flash

* attention straightforward masking

* fix

* tokenizer update: skipping blank tokens by default

* doc update

* fix max_position_embeddings handling

* nits

* change atol in feature extraction integration tests

* doc update + fix loss

* doc update

* nit

* update integration test for A10

* repo id name

* nit

---------

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
Co-authored-by: Eustache Le Bihan <eulebihan@gmail.com>
Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
Co-authored-by: Eric B <ebezzam@gmail.com>
Nithin Rao
2025-09-25 09:52:24 -04:00
committed by GitHub
parent 1dd22a234c
commit a579de7f5e
26 changed files with 3372 additions and 8 deletions

View File

@@ -935,6 +935,8 @@
title: MusicGen
- local: model_doc/musicgen_melody
title: MusicGen Melody
- local: model_doc/parakeet
title: Parakeet
- local: model_doc/pop2piano
title: Pop2Piano
- local: model_doc/seamless_m4t

View File

@@ -0,0 +1,220 @@
<!--Copyright 2025 The NVIDIA NeMo Team and The HuggingFace Inc. team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
# Parakeet
## Overview
Parakeet models, [introduced by NVIDIA NeMo](https://developer.nvidia.com/blog/pushing-the-boundaries-of-speech-recognition-with-nemo-parakeet-asr-models/), combine a [Fast Conformer](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/asr/models.html#fast-conformer) encoder with a connectionist temporal classification (CTC), recurrent neural network transducer (RNNT), or token-and-duration transducer (TDT) decoder for automatic speech recognition.
**Model Architecture**
- **Fast Conformer Encoder**: A linearly scalable Conformer architecture that processes mel-spectrogram features and reduces sequence length through subsampling. This is a more efficient version of the Conformer encoder found in [FastSpeech2Conformer](./fastspeech2_conformer.md) (see [`ParakeetEncoder`] for the encoder implementation and details).
- [**ParakeetForCTC**](#parakeetforctc): a Fast Conformer Encoder + a CTC decoder
- **CTC Decoder**: Simple but effective decoder (sketched in the example after this list) consisting of:
- 1D convolution projection from encoder hidden size to vocabulary size (for optimal NeMo compatibility).
- CTC loss computation for training.
- Greedy CTC decoding for inference.
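A minimal sketch of what this decoder amounts to, assuming the default hidden size (1024), vocabulary size (1025), and blank/pad id (1024) described in this PR; it is only an illustration, not the [`ParakeetForCTC`] implementation:

```python
import torch
import torch.nn as nn

hidden_size, vocab_size, blank_id = 1024, 1025, 1024  # defaults described in this PR

# 1D convolution projection from encoder hidden size to vocabulary size
ctc_head = nn.Conv1d(hidden_size, vocab_size, kernel_size=1)

encoder_out = torch.randn(1, 50, hidden_size)                    # (batch, frames, hidden)
logits = ctc_head(encoder_out.transpose(1, 2)).transpose(1, 2)   # (batch, frames, vocab)

# Greedy CTC decoding: per-frame argmax, collapse repeated frames, drop blank/pad ids
ids = logits.argmax(dim=-1)[0].tolist()
token_ids = [t for i, t in enumerate(ids) if t != blank_id and (i == 0 or t != ids[i - 1])]
print(token_ids)  # token ids a tokenizer would then decode to text
```

In the actual model, `processor.batch_decode(model.generate(**inputs))` performs this decoding for you, as shown in the usage examples below.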
The original implementation can be found in [NVIDIA NeMo](https://github.com/NVIDIA/NeMo).
Model checkpoints can be found under [the NVIDIA organization](https://huggingface.co/nvidia/models?search=parakeet).
This model was contributed by [Nithin Rao Koluguri](https://huggingface.co/nithinraok), [Eustache Le Bihan](https://huggingface.co/eustlb) and [Eric Bezzam](https://huggingface.co/bezzam).
## Usage
### Basic usage
<hfoptions id="usage">
<hfoption id="Pipeline">
```py
from transformers import pipeline
pipe = pipeline("automatic-speech-recognition", model="nvidia/parakeet-ctc-1.1b")
out = pipe("https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3")
print(out)
```
</hfoption>
<hfoption id="AutoModel">
```py
from transformers import AutoModelForCTC, AutoProcessor
from datasets import load_dataset, Audio
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained("nvidia/parakeet-ctc-1.1b")
model = AutoModelForCTC.from_pretrained("nvidia/parakeet-ctc-1.1b", dtype="auto", device_map=device)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
speech_samples = [el['array'] for el in ds["audio"][:5]]
inputs = processor(speech_samples, sampling_rate=processor.feature_extractor.sampling_rate)
inputs.to(model.device, dtype=model.dtype)
outputs = model.generate(**inputs)
print(processor.batch_decode(outputs))
```
</hfoption>
</hfoptions>
### Making The Model Go Brrr
Parakeet supports full-graph compilation with CUDA graphs! This optimization is most effective when you know the maximum audio length you want to transcribe. The key idea is using static input shapes to avoid recompilation. For example, if you know your audio will be under 30 seconds, you can use the processor to pad all inputs to 30 seconds, preparing consistent input features and attention masks. See the example below!
```python
from transformers import AutoModelForCTC, AutoProcessor
from datasets import load_dataset, Audio
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained("nvidia/parakeet-ctc-1.1b")
model = AutoModelForCTC.from_pretrained("nvidia/parakeet-ctc-1.1b", dtype="auto", device_map=device)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
speech_samples = [el['array'] for el in ds["audio"][:5]]
# Compile the generate method with fullgraph and CUDA graphs
model.generate = torch.compile(model.generate, fullgraph=True, mode="reduce-overhead")
# let's define processor kwargs to pad to 30 seconds
processor_kwargs = {
"padding": "max_length",
"max_length": 30 * processor.feature_extractor.sampling_rate,
}
# Define a timing context using CUDA events
class TimerContext:
def __init__(self, name="Execution"):
self.name = name
self.start_event = None
self.end_event = None
def __enter__(self):
# Use CUDA events for more accurate GPU timing
self.start_event = torch.cuda.Event(enable_timing=True)
self.end_event = torch.cuda.Event(enable_timing=True)
self.start_event.record()
return self
def __exit__(self, *args):
self.end_event.record()
torch.cuda.synchronize()
elapsed_time = self.start_event.elapsed_time(self.end_event) / 1000.0
print(f"{self.name} time: {elapsed_time:.4f} seconds")
inputs = processor(speech_samples[0], **processor_kwargs)
inputs.to(device, dtype=model.dtype)
print("\n" + "="*50)
print("First generation - compiling...")
# Generate with the compiled model
with TimerContext("First generation"):
outputs = model.generate(**inputs)
print(processor.batch_decode(outputs))
inputs = processor(speech_samples[1], **processor_kwargs)
inputs.to(device, dtype=model.dtype)
print("\n" + "="*50)
print("Second generation - recording CUDA graphs...")
with TimerContext("Second generation"):
outputs = model.generate(**inputs)
print(processor.batch_decode(outputs))
inputs = processor(speech_samples[2], **processor_kwargs)
inputs.to(device, dtype=model.dtype)
print("\n" + "="*50)
print("Third generation - fast !!!")
with TimerContext("Third generation"):
outputs = model.generate(**inputs)
print(processor.batch_decode(outputs))
inputs = processor(speech_samples[3], **processor_kwargs)
inputs.to(device, dtype=model.dtype)
print("\n" + "="*50)
print("Fourth generation - still fast !!!")
with TimerContext("Fourth generation"):
outputs = model.generate(**inputs)
print(processor.batch_decode(outputs))
```
### Training
```python
from transformers import AutoModelForCTC, AutoProcessor
from datasets import load_dataset, Audio
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained("nvidia/parakeet-ctc-1.1b")
model = AutoModelForCTC.from_pretrained("nvidia/parakeet-ctc-1.1b", dtype="auto", device_map=device)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
speech_samples = [el['array'] for el in ds["audio"][:5]]
text_samples = [el for el in ds["text"][:5]]
# passing `text` to the processor will prepare inputs' `labels` key
inputs = processor(audio=speech_samples, text=text_samples, sampling_rate=processor.feature_extractor.sampling_rate)
inputs.to(device, dtype=model.dtype)
outputs = model(**inputs)
outputs.loss.backward()
```
## ParakeetTokenizerFast
[[autodoc]] ParakeetTokenizerFast
## ParakeetFeatureExtractor
[[autodoc]] ParakeetFeatureExtractor
- __call__
## ParakeetProcessor
[[autodoc]] ParakeetProcessor
- __call__
- batch_decode
- decode
## ParakeetEncoderConfig
[[autodoc]] ParakeetEncoderConfig
## ParakeetCTCConfig
[[autodoc]] ParakeetCTCConfig
## ParakeetEncoder
[[autodoc]] ParakeetEncoder
## ParakeetForCTC
[[autodoc]] ParakeetForCTC

View File

@@ -1540,6 +1540,54 @@ class HeliumConverter(SpmConverter):
)
class ParakeetConverter(SpmConverter):
handle_byte_fallback = True
def __init__(self, vocab_file=None, *args):
self.vocab_file = vocab_file
requires_backends(self, "protobuf")
Converter.__init__(self, vocab_file)
model_pb2 = import_protobuf()
m = model_pb2.ModelProto()
with open(vocab_file, "rb") as f:
m.ParseFromString(f.read())
self.proto = m
def tokenizer(self, proto):
vocab_scores = self.vocab(proto)
_, merges = self.SpmExtractor(self.vocab_file).extract(vocab_scores)
bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
tokenizer = Tokenizer(
BPE(
bpe_vocab,
merges,
unk_token=proto.trainer_spec.unk_piece,
fuse_unk=True,
byte_fallback=self.handle_byte_fallback,
dropout=None,
)
)
# Add user defined symbols and control tokens from sentencepiece model
spm_added_tokens = [
(id, p.piece, p.type == 3 or p.piece in self.special_tokens)
for id, p in enumerate(proto.pieces)
if p.type in [3, 4]
]
tokenizer.add_tokens(
[
AddedToken(token, normalized=False, special=special)
for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
]
)
return tokenizer
# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
def bytes_to_unicode():
"""

View File

@@ -253,6 +253,7 @@ if TYPE_CHECKING:
from .owlv2 import *
from .owlvit import *
from .paligemma import *
from .parakeet import *
from .patchtsmixer import *
from .patchtst import *
from .pegasus import *

View File

@@ -296,6 +296,8 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
("owlv2", "Owlv2Config"),
("owlvit", "OwlViTConfig"),
("paligemma", "PaliGemmaConfig"),
("parakeet_ctc", "ParakeetCTCConfig"),
("parakeet_encoder", "ParakeetEncoderConfig"),
("patchtsmixer", "PatchTSMixerConfig"),
("patchtst", "PatchTSTConfig"),
("pegasus", "PegasusConfig"),
@@ -745,6 +747,9 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
("owlv2", "OWLv2"),
("owlvit", "OWL-ViT"),
("paligemma", "PaliGemma"),
("parakeet", "Parakeet"),
("parakeet_ctc", "Parakeet"),
("parakeet_encoder", "ParakeetEncoder"),
("patchtsmixer", "PatchTSMixer"),
("patchtst", "PatchTST"),
("pegasus", "Pegasus"),
@@ -984,6 +989,8 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict[str, str](
("blip_2_qformer", "blip_2"),
("fastspeech2_conformer_with_hifigan", "fastspeech2_conformer"),
("perception_encoder", "perception_lm"),
("parakeet_encoder", "parakeet"),
("parakeet_ctc", "parakeet"),
]
)

View File

@@ -81,6 +81,8 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
("moshi", "EncodecFeatureExtractor"),
("nat", "ViTFeatureExtractor"),
("owlvit", "OwlViTFeatureExtractor"),
("parakeet_ctc", "ParakeetFeatureExtractor"),
("parakeet_encoder", "ParakeetFeatureExtractor"),
("perceiver", "PerceiverFeatureExtractor"),
("phi4_multimodal", "Phi4MultimodalFeatureExtractor"),
("poolformer", "PoolFormerFeatureExtractor"),

View File

@@ -295,6 +295,8 @@ MODEL_MAPPING_NAMES = OrderedDict(
("owlv2", "Owlv2Model"),
("owlvit", "OwlViTModel"),
("paligemma", "PaliGemmaModel"),
("parakeet_ctc", "ParakeetForCTC"),
("parakeet_encoder", "ParakeetEncoder"),
("patchtsmixer", "PatchTSMixerModel"),
("patchtst", "PatchTSTModel"),
("pegasus", "PegasusModel"),
@@ -1601,6 +1603,7 @@ MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
("data2vec-audio", "Data2VecAudioForCTC"),
("hubert", "HubertForCTC"),
("mctct", "MCTCTForCTC"),
("parakeet_ctc", "ParakeetForCTC"),
("sew", "SEWForCTC"),
("sew-d", "SEWDForCTC"),
("unispeech", "UniSpeechForCTC"),

View File

@@ -502,6 +502,7 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("owlv2", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
("paligemma", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("parakeet", ("ParakeetCTCTokenizer", None)),
(
"pegasus",
(

View File

@@ -21,6 +21,7 @@ from typing import Optional, Union
import torch
from torch import nn
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import ModelOutput, auto_docstring, logging
@@ -472,24 +473,37 @@ class FastSpeech2ConformerAttention(nn.Module):
class FastSpeech2ConformerConvolutionModule(nn.Module):
def __init__(self, config: FastSpeech2ConformerConfig, module_config):
def __init__(self, config: FastSpeech2ConformerConfig, module_config=None):
"""
Args:
config (FastSpeech2ConformerConfig): Configuration for the model.
module_config (dict): Configuration for the module (e.g., encoder or decoder).
"""
super().__init__()
# kernel_size should be an odd number for 'SAME' padding
channels = config.hidden_size
kernel_size = module_config["kernel_size"]
# kernel_size should be an odd number for 'SAME' padding
if module_config is None:
# e.g. using `ParakeetEncoderConfig` in src/transformers/models/parakeet/configuration_parakeet.py
kernel_size = config.conv_kernel_size
self.activation = ACT2FN[getattr(config, "hidden_act", "silu")]
else:
kernel_size = module_config["kernel_size"]
self.activation = ACT2FN[module_config.get("activation", "silu")]
self.padding = (kernel_size - 1) // 2
self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=True)
self.depthwise_conv = nn.Conv1d(
channels, channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2, groups=channels, bias=True
channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=True
)
self.norm = nn.BatchNorm1d(channels)
self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=True)
def forward(self, hidden_states):
def forward(self, hidden_states, attention_mask=None):
"""
Compute convolution module.
Args:
hidden_states (`torch.Tensor` of shape `(batch, time, channels)`): Input tensor.
attention_mask (`torch.Tensor` of shape `(batch, 1, time)`): Attention mask.
Returns:
`torch.Tensor`: Output tensor of shape `(batch, time, channels)`.
@@ -503,12 +517,15 @@ class FastSpeech2ConformerConvolutionModule(nn.Module):
# (batch_size, channel, dim)
hidden_states = nn.functional.glu(hidden_states, dim=1)
# Apply padding mask before convolution
if attention_mask is not None:
all_masked_rows = torch.all(~attention_mask, dim=-1)
hidden_states = hidden_states.masked_fill(all_masked_rows, 0.0)
# 1D Depthwise Conv
hidden_states = self.depthwise_conv(hidden_states)
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states * torch.sigmoid(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.pointwise_conv2(hidden_states)
return hidden_states.transpose(1, 2)

View File

@@ -0,0 +1,29 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_parakeet import *
from .feature_extraction_parakeet import *
from .modeling_parakeet import *
from .tokenization_parakeet_fast import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@@ -0,0 +1,235 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parakeet model configuration."""
from typing import Union
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
class ParakeetEncoderConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`ParakeetEncoder`]. It is used to instantiate a
`ParakeetEncoder` model according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 1024):
Dimension of the layers and the hidden states.
num_hidden_layers (`int`, *optional*, defaults to 24):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 8):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (`int`, *optional*, defaults to 4096):
Dimension of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the encoder and pooler.
attention_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in the attention layers.
conv_kernel_size (`int`, *optional*, defaults to 9):
The kernel size of the convolution layers in the Conformer block.
subsampling_factor (`int`, *optional*, defaults to 8):
The factor by which the input sequence is subsampled.
subsampling_conv_channels (`int`, *optional*, defaults to 256):
The number of channels in the subsampling convolution layers.
num_mel_bins (`int`, *optional*, defaults to 80):
Number of mel features.
subsampling_conv_kernel_size (`int`, *optional*, defaults to 3):
The kernel size of the subsampling convolution layers.
subsampling_conv_stride (`int`, *optional*, defaults to 2):
The stride of the subsampling convolution layers.
dropout (`float`, *optional*, defaults to 0.1):
The dropout ratio for all fully connected layers in the embeddings, encoder, and pooler.
dropout_positions (`float`, *optional*, defaults to 0.0):
The dropout ratio for the positions in the input sequence.
layerdrop (`float`, *optional*, defaults to 0.1):
The dropout ratio for the layers in the encoder.
activation_dropout (`float`, *optional*, defaults to 0.1):
The dropout ratio for activations inside the fully connected layer.
attention_dropout (`float`, *optional*, defaults to 0.1):
The dropout ratio for the attention layers.
max_position_embeddings (`int`, *optional*, defaults to 5000):
The maximum sequence length that this model might ever be used with.
scale_input (`bool`, *optional*, defaults to `True`):
Whether to scale the input embeddings.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
Example:
```python
>>> from transformers import ParakeetEncoderModel, ParakeetEncoderConfig
>>> # Initializing a `ParakeetEncoder` configuration
>>> configuration = ParakeetEncoderConfig()
>>> # Initializing a model from the configuration
>>> model = ParakeetEncoderModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
This configuration class is based on the ParakeetEncoder architecture from NVIDIA NeMo. You can find more details
and pre-trained models at [nvidia/parakeet-ctc-1.1b](https://huggingface.co/nvidia/parakeet-ctc-1.1b).
"""
model_type = "parakeet_encoder"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=8,
intermediate_size=4096,
hidden_act="silu",
attention_bias=True,
conv_kernel_size=9,
subsampling_factor=8,
subsampling_conv_channels=256,
num_mel_bins=80,
subsampling_conv_kernel_size=3,
subsampling_conv_stride=2,
dropout=0.1,
dropout_positions=0.0,
layerdrop=0.1,
activation_dropout=0.1,
attention_dropout=0.1,
max_position_embeddings=5000,
scale_input=True,
initializer_range=0.02,
**kwargs,
):
super().__init__(
**kwargs,
)
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_attention_heads # LlamaAttention compatibility
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.attention_bias = attention_bias
if (conv_kernel_size - 1) % 2 != 0:
raise ValueError(f"conv_kernel_size must be odd, got {conv_kernel_size}")
self.conv_kernel_size = conv_kernel_size
self.subsampling_conv_kernel_size = subsampling_conv_kernel_size
self.subsampling_conv_stride = subsampling_conv_stride
self.subsampling_factor = subsampling_factor
self.subsampling_conv_channels = subsampling_conv_channels
self.num_mel_bins = num_mel_bins
self.dropout = dropout
self.dropout_positions = dropout_positions
self.layerdrop = layerdrop
self.activation_dropout = activation_dropout
self.attention_dropout = attention_dropout
self.max_position_embeddings = max_position_embeddings
self.scale_input = scale_input
self.initializer_range = initializer_range
class ParakeetCTCConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`ParakeetForCTC`]. It is used to instantiate a
Parakeet CTC model according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 1025):
Vocabulary size of the model.
ctc_loss_reduction (`str`, *optional*, defaults to `"mean"`):
Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
instance of [`ParakeetForCTC`].
ctc_zero_infinity (`bool`, *optional*, defaults to `True`):
Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
of [`ParakeetForCTC`].
encoder_config (`Union[dict, ParakeetEncoderConfig]`, *optional*):
The config object or dictionary of the encoder.
pad_token_id (`int`, *optional*, defaults to 1024):
Padding token id. Also used as blank token id.
Example:
```python
>>> from transformers import ParakeetForCTC, ParakeetCTCConfig
>>> # Initializing a Parakeet configuration
>>> configuration = ParakeetCTCConfig()
>>> # Initializing a model from the configuration
>>> model = ParakeetForCTC(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
This configuration class is based on the Parakeet CTC architecture from NVIDIA NeMo. You can find more details
and pre-trained models at [nvidia/parakeet-ctc-1.1b](https://huggingface.co/nvidia/parakeet-ctc-1.1b).
"""
model_type = "parakeet_ctc"
sub_configs = {"encoder_config": ParakeetEncoderConfig}
def __init__(
self,
vocab_size=1025,
ctc_loss_reduction="mean",
ctc_zero_infinity=True,
encoder_config: Union[dict, ParakeetEncoderConfig] = None,
pad_token_id=1024,
**kwargs,
):
self.vocab_size = vocab_size
self.ctc_loss_reduction = ctc_loss_reduction
self.ctc_zero_infinity = ctc_zero_infinity
if isinstance(encoder_config, dict):
self.encoder_config = ParakeetEncoderConfig(**encoder_config)
elif encoder_config is None:
self.encoder_config = ParakeetEncoderConfig()
else:
self.encoder_config = encoder_config
self.initializer_range = self.encoder_config.initializer_range
super().__init__(
pad_token_id=pad_token_id,
**kwargs,
)
@classmethod
def from_encoder_config(cls, encoder_config: ParakeetEncoderConfig, **kwargs):
r"""
Instantiate a [`ParakeetCTCConfig`] (or a derived class) from parakeet encoder model configuration.
Returns:
[`ParakeetCTCConfig`]: An instance of a configuration object
"""
return cls(encoder_config=encoder_config.to_dict(), **kwargs)
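# Illustrative usage (values are examples only):
# config = ParakeetCTCConfig.from_encoder_config(ParakeetEncoderConfig(hidden_size=512, num_hidden_layers=8))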
__all__ = ["ParakeetCTCConfig", "ParakeetEncoderConfig"]

View File

@@ -0,0 +1,315 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import os
import re
import tarfile
import torch
import yaml
from tokenizers import AddedToken
from transformers import (
ParakeetCTCConfig,
ParakeetFeatureExtractor,
ParakeetForCTC,
ParakeetProcessor,
ParakeetTokenizerFast,
)
from transformers.convert_slow_tokenizer import ParakeetConverter
from transformers.utils.hub import cached_file
NEMO_TO_HF_WEIGHT_MAPPING = {
r"encoder\.pre_encode\.conv\.": r"encoder.subsampling.layers.",
r"encoder\.pre_encode\.out\.": r"encoder.subsampling.linear.",
r"encoder\.pos_enc\.": r"encoder.encode_positions.",
r"encoder\.layers\.(\d+)\.conv\.batch_norm\.": r"encoder.layers.\1.conv.norm.",
r"decoder\.decoder_layers\.0\.(weight|bias)": r"ctc_head.\1",
r"linear_([kv])": r"\1_proj",
r"linear_out": r"o_proj",
r"linear_q": r"q_proj",
r"pos_bias_([uv])": r"bias_\1",
r"linear_pos": r"relative_k_proj",
}
def convert_key(key, mapping):
for pattern, replacement in mapping.items():
key = re.sub(pattern, replacement, key)
return key
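# Illustrative example of the regex mapping above:
# convert_key("encoder.layers.0.conv.batch_norm.weight", NEMO_TO_HF_WEIGHT_MAPPING)
# returns "encoder.layers.0.conv.norm.weight"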
def extract_nemo_archive(nemo_file_path: str, extract_dir: str) -> dict[str, str]:
"""
Extract .nemo file (tar archive) and return paths to important files.
Args:
nemo_file_path: Path to .nemo file
extract_dir: Directory to extract to
Returns:
Dictionary with paths to model.pt, model_config.yaml, etc.
"""
print(f"Extracting NeMo archive: {nemo_file_path}")
with tarfile.open(nemo_file_path, "r", encoding="utf-8") as tar:
tar.extractall(extract_dir)
# Log all extracted files for debugging
all_files = []
for root, dirs, files in os.walk(extract_dir):
for file in files:
file_path = os.path.join(root, file)
all_files.append(file_path)
print(f"All extracted files: {[os.path.basename(f) for f in all_files]}")
# Find important files with more robust detection
model_files = {}
for root, dirs, files in os.walk(extract_dir):
for file in files:
file_path = os.path.join(root, file)
file_lower = file.lower()
# Look for model weights with various common names
if (
file.endswith(".pt")
or file.endswith(".pth")
or file.endswith(".ckpt")
or file.endswith(".bin")
or "model" in file_lower
and ("weight" in file_lower or "state" in file_lower)
or file_lower == "model.pt"
or file_lower == "pytorch_model.bin"
or file_lower == "model_weights.ckpt"
):
model_files["model_weights"] = file_path
print(f"Found model weights: {file}")
# Look for config files
elif (
file == "model_config.yaml"
or file == "config.yaml"
or (file.endswith(".yaml") and "config" in file_lower)
):
if "model_config" not in model_files: # Prefer model_config.yaml
model_files["model_config"] = file_path
print(f"Found config file: {file}")
if file == "model_config.yaml":
model_files["model_config"] = file_path # Override with preferred name
# Look for vocabulary files
elif (
file.endswith(".vocab")
or file.endswith(".model")
or file.endswith(".txt")
or ("tokenizer" in file_lower and (file.endswith(".vocab") or file.endswith(".model")))
):
# Prefer .vocab files over others
if "tokenizer_model_file" not in model_files or file.endswith(".model"):
model_files["tokenizer_model_file"] = file_path
print(f"Found tokenizer model file: {file}")
else:
print(f"Found additional vocabulary file (using existing): {file}")
print(f"Found model files: {list(model_files.keys())}")
# Validate that we found the required files
if "model_weights" not in model_files:
raise FileNotFoundError(
f"Could not find model weights file in {nemo_file_path}. "
f"Expected files with extensions: .pt, .pth, .ckpt, .bin. "
f"Found files: {[os.path.basename(f) for f in all_files]}"
)
if "model_config" not in model_files:
raise FileNotFoundError(
f"Could not find model config file in {nemo_file_path}. "
f"Expected: model_config.yaml or config.yaml. "
f"Found files: {[os.path.basename(f) for f in all_files]}"
)
return model_files
def write_processor(nemo_config: dict, model_files, output_dir, push_to_repo_id=None):
tokenizer_converted = ParakeetConverter(model_files["tokenizer_model_file"]).converted()
tokenizer_converted_fast = ParakeetTokenizerFast(
tokenizer_object=tokenizer_converted,
clean_up_tokenization_spaces=False,
)
tokenizer_converted_fast.add_tokens(
[AddedToken("<unk>", normalized=False, special=True), AddedToken("<pad>", normalized=False, special=True)]
)
tokenizer_converted_fast.add_special_tokens(
{
"pad_token": AddedToken("<pad>", normalized=False, special=True),
"unk_token": AddedToken("<unk>", normalized=False, special=True),
}
)
feature_extractor_keys_to_ignore = ["_target_", "pad_to", "frame_splicing", "dither", "normalize", "window", "log"]
feature_extractor_config_keys_mapping = {
"sample_rate": "sampling_rate",
"window_size": "win_length",
"window_stride": "hop_length",
"window": "window",
"n_fft": "n_fft",
"log": "log",
"features": "feature_size",
"dither": "dither",
"pad_to": "pad_to",
"pad_value": "padding_value",
"frame_splicing": "frame_splicing",
"preemphasis": "preemphasis",
"hop_length": "hop_length",
}
converted_feature_extractor_config = {}
for key, value in nemo_config["preprocessor"].items():
if key in feature_extractor_keys_to_ignore:
continue
if key in feature_extractor_config_keys_mapping:
if key in ["window_size", "window_stride"]:
value = int(value * nemo_config["preprocessor"]["sample_rate"])
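# e.g. (illustrative, matching the feature extractor defaults): a 0.025 s window at 16000 Hz
# becomes win_length = 400 samples, and a 0.01 s stride becomes hop_length = 160 samples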
converted_feature_extractor_config[feature_extractor_config_keys_mapping[key]] = value
else:
raise ValueError(f"Key {key} not found in feature_extractor_keys_mapping")
feature_extractor = ParakeetFeatureExtractor(**converted_feature_extractor_config)
processor = ParakeetProcessor(
feature_extractor=feature_extractor,
tokenizer=tokenizer_converted_fast,
)
processor.save_pretrained(output_dir)
if push_to_repo_id:
processor.push_to_hub(push_to_repo_id)
def write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_id=None):
encoder_keys_to_ignore = [
"att_context_size",
"causal_downsampling",
"stochastic_depth_start_layer",
"feat_out",
"stochastic_depth_drop_prob",
"_target_",
"ff_expansion_factor",
"untie_biases",
"att_context_style",
"self_attention_model",
"conv_norm_type",
"subsampling",
"stochastic_depth_mode",
"conv_context_size",
"dropout_pre_encoder",
]
enocder_config_keys_mapping = {
"d_model": "hidden_size",
"n_heads": "num_attention_heads",
"n_layers": "num_hidden_layers",
"feat_in": "num_mel_bins",
"conv_kernel_size": "conv_kernel_size",
"subsampling_factor": "subsampling_factor",
"subsampling_conv_channels": "subsampling_conv_channels",
"pos_emb_max_len": "max_position_embeddings",
"dropout": "dropout",
"dropout_emb": "dropout_positions",
"dropout_att": "attention_dropout",
"xscaling": "scale_input",
}
converted_encoder_config = {}
for key, value in nemo_config["encoder"].items():
if key in encoder_keys_to_ignore:
continue
if key in enocder_config_keys_mapping:
converted_encoder_config[enocder_config_keys_mapping[key]] = value
else:
raise ValueError(f"Key {key} not found in enocder_config_keys_mapping")
state_dict = torch.load(model_files["model_weights"], map_location="cpu", weights_only=True)
converted_state_dict = {}
for key, value in state_dict.items():
# Skip preprocessing weights (featurizer components)
if key.endswith("featurizer.window") or key.endswith("featurizer.fb"):
print(f"Skipping preprocessing weight: {key}")
continue
converted_key = convert_key(key, NEMO_TO_HF_WEIGHT_MAPPING)
converted_state_dict[converted_key] = value
if model_type == "ctc":
model_config = ParakeetCTCConfig(
encoder_config=converted_encoder_config,
)
print("Loading the checkpoint in a Parakeet CTC model.")
with torch.device("meta"):
model = ParakeetForCTC(model_config)
model.load_state_dict(converted_state_dict, strict=True, assign=True)
print("Checkpoint loaded successfully.")
del model.config._name_or_path
print("Saving the model.")
model.save_pretrained(output_dir)
if push_to_repo_id:
model.push_to_hub(push_to_repo_id)
del converted_state_dict, model
# Safety check: reload the converted model
gc.collect()
print("Reloading the model to check if it's saved correctly.")
ParakeetForCTC.from_pretrained(output_dir, dtype=torch.bfloat16, device_map="auto")
print("Model reloaded successfully.")
else:
raise ValueError(f"Model type {model_type} not supported.")
def main(
hf_repo_id,
output_dir,
model_type,
push_to_repo_id=None,
):
nemo_filename = f"{hf_repo_id.split('/')[-1]}.nemo"
filepath = cached_file(hf_repo_id, nemo_filename)
model_files = extract_nemo_archive(filepath, os.path.dirname(filepath))
nemo_config = yaml.load(open(model_files["model_config"], "r"), Loader=yaml.FullLoader)
write_processor(nemo_config, model_files, output_dir, push_to_repo_id)
write_model(nemo_config, model_files, model_type, output_dir, push_to_repo_id)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--hf_repo_id", required=True, help="Model repo on huggingface.co")
parser.add_argument("--model_type", required=True, choices=["ctc"], help="Model type (`ctc`, `tdt`)")
parser.add_argument("--output_dir", required=True, help="Output directory for HuggingFace model")
parser.add_argument("--push_to_repo_id", help="Repository ID to push the model to on the Hub")
args = parser.parse_args()
main(
args.hf_repo_id,
args.output_dir,
args.model_type,
args.push_to_repo_id,
)
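# Example invocation (illustrative; the converter filename below is a placeholder and the repo
# must host a matching `<repo-name>.nemo` file):
# python convert_nemo_parakeet_to_hf.py --hf_repo_id nvidia/parakeet-ctc-1.1b --model_type ctc --output_dir ./parakeet-ctc-1.1b-hf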

View File

@@ -0,0 +1,287 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
import numpy as np
import torch
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import TensorType, is_librosa_available, logging
from ...utils.import_utils import requires
if is_librosa_available():
import librosa
EPSILON = 1e-5
LOG_ZERO_GUARD_VALUE = 2**-24
logger = logging.get_logger(__name__)
@requires(backends=("torch", "librosa"))
class ParakeetFeatureExtractor(SequenceFeatureExtractor):
r"""
Constructs a Parakeet feature extractor.
This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
most of the main methods. Users should refer to this superclass for more information regarding those methods.
This class extracts mel filter bank features from raw speech using PyTorch's `torch.stft` implementation of the
Short Time Fourier Transform.
Args:
feature_size (`int`, *optional*, defaults to 80):
The feature dimension of the extracted features.
sampling_rate (`int`, *optional*, defaults to 16000):
The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
hop_length (`int`, *optional*, defaults to 160):
Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients.
n_fft (`int`, *optional*, defaults to 512):
Size of the Fourier transform.
win_length (`int`, *optional*, defaults to 400):
The window length for the STFT computation.
preemphasis (`float`, *optional*, defaults to 0.97):
A preemphasis filter coefficient. 0.0 means no preemphasis filter.
padding_value (`float`, *optional*, defaults to 0.0):
Padding value used to pad the audio. Should correspond to silences.
"""
model_input_names = ["input_features", "attention_mask"]
def __init__(
self,
feature_size=80,
sampling_rate=16000,
hop_length=160,
n_fft=512,
win_length=400,
preemphasis=0.97,
padding_value=0.0,
**kwargs,
):
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
self.hop_length = hop_length
self.n_fft = n_fft
self.win_length = win_length
self.preemphasis = preemphasis
# TODO: @eustlb, for now we use librosa to compute the mel filters
# indeed mel_filter_bank uses np.float64 (while librosa uses np.float32), giving numerical differences
# self.mel_filters = mel_filter_bank(
# num_frequency_bins=n_fft // 2 + 1,
# num_mel_filters=feature_size,
# min_frequency=0.0,
# max_frequency=sampling_rate / 2,
# sampling_rate=sampling_rate,
# norm="slaney",
# mel_scale="slaney",
# )
mel_filters = librosa.filters.mel(
sr=sampling_rate, n_fft=n_fft, n_mels=feature_size, fmin=0.0, fmax=sampling_rate / 2, norm="slaney"
)
self.mel_filters = torch.from_numpy(mel_filters).to(torch.float32)
def _torch_extract_fbank_features(self, waveform, device="cpu"):
# spectrogram
window = torch.hann_window(self.win_length, periodic=False, device=device)
stft = torch.stft(
waveform,
self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=window,
return_complex=True,
pad_mode="constant",
)
# Let's match the original implementation
# magnitudes = torch.abs(stft) ** 2
magnitudes = torch.view_as_real(stft)
magnitudes = torch.sqrt(magnitudes.pow(2).sum(-1))
magnitudes = magnitudes.pow(2)
# log mel spectrogram
mel_filters = self.mel_filters.to(device)
mel_spec = mel_filters @ magnitudes
mel_spec = torch.log(mel_spec + LOG_ZERO_GUARD_VALUE)
# (batch_size, num_mel_filters, num_frames) -> (batch_size, num_frames, num_mel_filters)
mel_spec = mel_spec.permute(0, 2, 1)
return mel_spec
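# Illustrative shape check with the defaults above: a 1 s mono 16 kHz waveform, i.e. input of shape
# (1, 16000), gives 16000 // hop_length (160) + 1 = 101 frames, so mel_spec has shape (1, 101, 80).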
def __call__(
self,
raw_speech: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
truncation: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_attention_mask: Optional[bool] = None,
padding: Optional[str] = "longest",
max_length: Optional[int] = None,
sampling_rate: Optional[int] = None,
do_normalize: Optional[bool] = None,
device: Optional[str] = "cpu",
return_token_timestamps: Optional[bool] = None,
**kwargs,
) -> BatchFeature:
"""
Main method to featurize and prepare for the model one or several sequence(s). The STFT and log-mel features
are computed with PyTorch.
Args:
raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
stereo, i.e. single float per timestep.
truncation (`bool`, *optional*, defaults to `False`):
Activates truncation to cut input sequences longer than *max_length* to *max_length*.
pad_to_multiple_of (`int`, *optional*, defaults to None):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
`>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
return_attention_mask (`bool`, *optional*):
Whether to return the attention mask. If left to the default, will return the attention mask according
to the specific feature_extractor's default.
[What are attention masks?](../glossary#attention-mask)
<Tip>
For Parakeet models, `attention_mask` should always be passed for batched inference, to avoid subtle
bugs.
</Tip>
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors instead of list of python integers. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return Numpy `np.ndarray` objects.
sampling_rate (`int`, *optional*):
The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
`sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition
pipeline.
padding_value (`float`, *optional*, defaults to 0.0):
The value that is used to fill the padding values / vectors.
do_normalize (`bool`, *optional*, defaults to `False`):
Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
improve the performance of the model.
device (`str`, *optional*, defaults to `'cpu'`):
Specifies the device for computation of the log-mel spectrogram of audio signals in the
`_torch_extract_fbank_features` method. (e.g., "cpu", "cuda")
return_token_timestamps (`bool`, *optional*, defaults to `None`):
Deprecated. Use `return_attention_mask` instead from which the number of frames can be inferred.
Whether or not to return the number of frames of the input raw_speech.
These num_frames can be used by the model to compute word level timestamps.
"""
if sampling_rate is not None:
if sampling_rate != self.sampling_rate:
raise ValueError(
f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
f" was sampled with {self.sampling_rate} and not {sampling_rate}."
)
else:
logger.warning(
f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
"Failing to do so can result in silent errors that might be hard to debug."
)
# Convert to torch tensor
if isinstance(raw_speech, np.ndarray):
raw_speech = torch.tensor(raw_speech)
elif isinstance(raw_speech, (list, tuple)) and isinstance(raw_speech[0], np.ndarray):
raw_speech = [torch.tensor(speech) for speech in raw_speech]
is_batched_torch = isinstance(raw_speech, torch.Tensor) and len(raw_speech.shape) > 1
if is_batched_torch and len(raw_speech.shape) > 2:
logger.warning(
f"Only mono-channel audio is supported for input to {self.__class__.__name__}. "
"We will take the mean of the channels to convert to mono."
)
raw_speech = raw_speech.mean(-1)
is_batched_sequence = isinstance(raw_speech, (list, tuple))
if is_batched_sequence:
for speech in raw_speech:
if len(speech.shape) > 1:
logger.warning(
f"Only mono-channel audio is supported for input to {self.__class__.__name__}. "
"We will take the mean of the channels to convert to mono."
)
speech = speech.mean(-1)
if is_batched_torch or is_batched_sequence:
raw_speech = [speech[:, None].to(torch.float32) for speech in raw_speech]
else:
raw_speech = [raw_speech[:, None].to(torch.float32)]
audio_lengths = [len(speech) for speech in raw_speech]
batched_speech = BatchFeature({"input_features": raw_speech, "audio_lengths": audio_lengths})
padded_inputs = self.pad(
batched_speech,
padding=padding,
max_length=max_length,
truncation=truncation,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors="pt",
)
input_features = padded_inputs.input_features.squeeze(-1)
# preemphasis
if self.preemphasis is not None:
timemask = torch.arange(input_features.shape[1], device=input_features.device).unsqueeze(
0
) < padded_inputs.audio_lengths.unsqueeze(1)
input_features = torch.cat(
[input_features[:, :1], input_features[:, 1:] - self.preemphasis * input_features[:, :-1]], dim=1
)
input_features = input_features.masked_fill(~timemask, 0.0)
input_features = self._torch_extract_fbank_features(input_features, device)
features_lengths = torch.floor_divide(
padded_inputs.audio_lengths + self.n_fft // 2 * 2 - self.n_fft, self.hop_length
)
attention_mask = torch.arange(input_features.shape[1], device=device)[None, :] < features_lengths[:, None]
# normalize mel features, ignoring padding
mask = attention_mask.unsqueeze(-1)
input_features_masked = input_features * mask
mean = input_features_masked.sum(dim=1) / features_lengths.unsqueeze(-1)
mean = mean.unsqueeze(1)
variance = ((input_features_masked - mean) ** 2 * mask).sum(dim=1) / (features_lengths - 1).unsqueeze(-1)
std = torch.sqrt(variance).unsqueeze(1)
input_features = (input_features - mean) / (std + EPSILON)
input_features *= mask
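# i.e. per-utterance, per-mel-bin zero-mean / unit-variance normalization, with statistics computed over valid (unpadded) frames only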
return BatchFeature(
data={
"input_features": input_features,
"attention_mask": attention_mask,
},
tensor_type=return_tensors,
)
__all__ = ["ParakeetFeatureExtractor"]

View File

@@ -0,0 +1,744 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/parakeet/modular_parakeet.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_parakeet.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Callable, Optional, Union
import torch
from torch import nn
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, CausalLMOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.deprecation import deprecate_kwarg
from ...utils.generic import check_model_inputs
from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig
class ParakeetEncoderRelPositionalEncoding(nn.Module):
"""Relative positional encoding for Parakeet."""
inv_freq: torch.Tensor # fix linting for `register_buffer`
def __init__(self, config: ParakeetEncoderConfig, device=None):
super().__init__()
self.max_position_embeddings = config.max_position_embeddings
base = 10000.0
inv_freq = 1.0 / (
base
** (
torch.arange(0, config.hidden_size, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
/ config.hidden_size
)
)
self.register_buffer("inv_freq", inv_freq, persistent=False)
@torch.no_grad()
def forward(self, hidden_states: torch.Tensor):
seq_length = hidden_states.shape[1]
if seq_length > self.max_position_embeddings:
raise ValueError(
f"Sequence Length: {seq_length} has to be less or equal than "
f"config.max_position_embeddings {self.max_position_embeddings}."
)
position_ids = torch.arange(seq_length - 1, -seq_length, -1, device=hidden_states.device)
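# relative offsets run from (seq_len - 1) down to -(seq_len - 1), i.e. 2 * seq_len - 1 positions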
inv_freq_expanded = (
self.inv_freq[None, :, None].float().expand(hidden_states.shape[0], -1, 1).to(hidden_states.device)
)
position_ids_expanded = position_ids[None, None, :].float()
device_type = (
hidden_states.device.type
if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
else "cpu"
)
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
sin = freqs.sin()
cos = freqs.cos()
# interleave sin and cos
pos_embed = torch.stack([sin, cos], dim=-1)
pos_embed = pos_embed.reshape(*pos_embed.shape[:-2], -1)
return pos_embed.to(dtype=hidden_states.dtype)
class ParakeetEncoderFeedForward(nn.Module):
def __init__(self, config: ParakeetEncoderConfig):
super().__init__()
self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.attention_bias)
self.activation = ACT2FN[config.hidden_act]
self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.attention_bias)
self.activation_dropout = config.activation_dropout
def forward(self, hidden_states):
hidden_states = self.activation(self.linear1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.linear2(hidden_states)
return hidden_states
class ParakeetEncoderConvolutionModule(nn.Module):
def __init__(self, config: ParakeetEncoderConfig, module_config=None):
"""
Args:
config (ParakeetEncoderConfig): Configuration for the model.
module_config (dict): Configuration for the module (e.g., encoder or decoder).
"""
super().__init__()
channels = config.hidden_size
# kernel_size should be an odd number for 'SAME' padding
if module_config is None:
# e.g. using `ParakeetEncoderConfig` in src/transformers/models/parakeet/configuration_parakeet.py
kernel_size = config.conv_kernel_size
self.activation = ACT2FN[getattr(config, "hidden_act", "silu")]
else:
kernel_size = module_config["kernel_size"]
self.activation = ACT2FN[module_config.get("activation", "silu")]
self.padding = (kernel_size - 1) // 2
self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=True)
self.depthwise_conv = nn.Conv1d(
channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=True
)
self.norm = nn.BatchNorm1d(channels)
self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=True)
def forward(self, hidden_states, attention_mask=None):
"""
Compute convolution module.
Args:
hidden_states (`torch.Tensor` of shape `(batch, time, channels)`): Input tensor.
attention_mask (`torch.Tensor` of shape `(batch, 1, time)`): Attention mask.
Returns:
`torch.Tensor`: Output tensor of shape `(batch, time, channels)`.
"""
# exchange the temporal dimension and the feature dimension
hidden_states = hidden_states.transpose(1, 2)
# GLU mechanism, (batch_size, 2*channel, dim)
hidden_states = self.pointwise_conv1(hidden_states)
# (batch_size, channel, dim)
hidden_states = nn.functional.glu(hidden_states, dim=1)
# Apply padding mask before convolution
if attention_mask is not None:
all_masked_rows = torch.all(~attention_mask, dim=-1)
hidden_states = hidden_states.masked_fill(all_masked_rows, 0.0)
# 1D Depthwise Conv
hidden_states = self.depthwise_conv(hidden_states)
hidden_states = self.norm(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.pointwise_conv2(hidden_states)
return hidden_states.transpose(1, 2)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs: Unpack[TransformersKwargs],
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
class ParakeetEncoderAttention(nn.Module):
"""Multi-head attention with relative positional encoding. See section 3.3 of https://huggingface.co/papers/1901.02860."""
def __init__(self, config: ParakeetEncoderConfig, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = False
self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
)
self.k_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.v_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
# W_{k,R} projection
self.relative_k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
# global content bias
self.bias_u = nn.Parameter(torch.zeros(config.num_attention_heads, self.head_dim))
# global positional bias
self.bias_v = nn.Parameter(torch.zeros(config.num_attention_heads, self.head_dim))
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, torch.Tensor]:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
hidden_shape = (batch_size, seq_length, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
query_states_with_bias_u = query_states + self.bias_u.view(
1, self.config.num_attention_heads, 1, self.head_dim
)
query_states_with_bias_v = query_states + self.bias_v.view(
1, self.config.num_attention_heads, 1, self.head_dim
)
relative_key_states = self.relative_k_proj(position_embeddings)
relative_key_states = relative_key_states.view(batch_size, -1, self.config.num_attention_heads, self.head_dim)
# terms (b) and (d)
matrix_bd = query_states_with_bias_v @ relative_key_states.permute(0, 2, 3, 1)
matrix_bd = self._rel_shift(matrix_bd)
matrix_bd = matrix_bd[..., :seq_length]
matrix_bd = matrix_bd * self.scaling
if attention_mask is not None:
# the original codebase uses -10000.0 rather than float("-inf"), and then manually fills the masked positions with 0.0
# see: https://github.com/NVIDIA-NeMo/NeMo/blob/8cfedd7203462cb251a914e700e5605444277561/nemo/collections/asr/parts/submodules/multi_head_attention.py#L320-L340
# we instead take the straightforward approach of using float("-inf")
matrix_bd = matrix_bd.masked_fill_(attention_mask.logical_not(), float("-inf"))
# will compute matrix_ac - terms (a) and (c) - and add matrix_bd
attn_output, attn_weights = attention_interface(
self,
query=query_states_with_bias_u,
key=key_states,
value=value_states,
attention_mask=matrix_bd,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
def _rel_shift(self, attention_scores):
"""Relative position shift for Shaw et al. style attention. See appendix B of https://huggingface.co/papers/1901.02860."""
batch_size, num_heads, query_length, position_length = attention_scores.shape
attention_scores = nn.functional.pad(attention_scores, pad=(1, 0))
attention_scores = attention_scores.view(batch_size, num_heads, -1, query_length)
attention_scores = attention_scores[:, :, 1:].view(batch_size, num_heads, query_length, position_length)
return attention_scores
class ParakeetEncoderSubsamplingConv2D(nn.Module):
def __init__(self, config: ParakeetEncoderConfig):
super().__init__()
self.kernel_size = config.subsampling_conv_kernel_size
self.stride = config.subsampling_conv_stride
self.channels = config.subsampling_conv_channels
self.padding = (self.kernel_size - 1) // 2
self.num_layers = int(math.log2(config.subsampling_factor))
# define layers
self.layers = nn.ModuleList()
self.layers.append(
nn.Conv2d(1, self.channels, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding)
)
self.layers.append(nn.ReLU())
for i in range(self.num_layers - 1):
# depthwise conv
self.layers.append(
nn.Conv2d(
self.channels,
self.channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
groups=self.channels,
)
)
# pointwise conv
self.layers.append(nn.Conv2d(self.channels, self.channels, kernel_size=1))
# activation
self.layers.append(nn.ReLU())
out_length = config.num_mel_bins // (self.stride**self.num_layers)
self.linear = nn.Linear(config.subsampling_conv_channels * out_length, config.hidden_size, bias=True)
def _get_output_length(self, input_lengths: torch.Tensor, conv_layer: nn.Conv2d):
if hasattr(conv_layer, "stride") and conv_layer.stride != (1, 1):
padding = conv_layer.padding
kernel_size = conv_layer.kernel_size[0]
stride = conv_layer.stride[0]
output_lengths = (input_lengths + padding[0] + padding[1] - kernel_size) // stride + 1
return output_lengths
return input_lengths
def forward(self, input_features: torch.Tensor, attention_mask: torch.Tensor = None):
hidden_states = input_features.unsqueeze(1)
current_lengths = attention_mask.sum(-1) if attention_mask is not None else None
for layer in self.layers:
hidden_states = layer(hidden_states)
# mask the hidden states
if isinstance(layer, nn.Conv2d) and attention_mask is not None:
current_lengths = self._get_output_length(current_lengths, layer)
current_seq_length = hidden_states.shape[2]
channel_mask = (
torch.arange(current_seq_length, device=attention_mask.device) < current_lengths[:, None]
)
hidden_states *= channel_mask[:, None, :, None]
hidden_states = hidden_states.transpose(1, 2).reshape(hidden_states.shape[0], hidden_states.shape[2], -1)
hidden_states = self.linear(hidden_states)
return hidden_states
class ParakeetEncoderBlock(GradientCheckpointingLayer):
def __init__(self, config: ParakeetEncoderConfig, layer_idx: Optional[int] = None):
super().__init__()
self.gradient_checkpointing = False
self.feed_forward1 = ParakeetEncoderFeedForward(config)
self.self_attn = ParakeetEncoderAttention(config, layer_idx)
self.conv = ParakeetEncoderConvolutionModule(config)
self.feed_forward2 = ParakeetEncoderFeedForward(config)
self.norm_feed_forward1 = nn.LayerNorm(config.hidden_size)
self.norm_self_att = nn.LayerNorm(config.hidden_size)
self.norm_conv = nn.LayerNorm(config.hidden_size)
self.norm_feed_forward2 = nn.LayerNorm(config.hidden_size)
self.norm_out = nn.LayerNorm(config.hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.feed_forward1(self.norm_feed_forward1(hidden_states))
hidden_states = residual + 0.5 * hidden_states # the conformer architecture uses a factor of 0.5
normalized_hidden_states = self.norm_self_att(hidden_states)
attn_output, _ = self.self_attn(
hidden_states=normalized_hidden_states,
attention_mask=attention_mask,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = hidden_states + attn_output
conv_output = self.conv(self.norm_conv(hidden_states), attention_mask=attention_mask)
hidden_states = hidden_states + conv_output
ff2_output = self.feed_forward2(self.norm_feed_forward2(hidden_states))
hidden_states = hidden_states + 0.5 * ff2_output # the conformer architecture uses a factor of 0.5
hidden_states = self.norm_out(hidden_states)
return hidden_states
@auto_docstring
class ParakeetPreTrainedModel(PreTrainedModel):
config: ParakeetCTCConfig
base_model_prefix = "model"
main_input_name = "input_features"
supports_gradient_checkpointing = True
_no_split_modules = ["ParakeetEncoderBlock"]
_supports_flat_attention_mask = True
_supports_sdpa = True
_supports_flex_attn = True
# TODO: @eustlb, add support when flash attention supports custom attention bias
_supports_flash_attn = False
_can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": ParakeetEncoderBlock,
"attentions": ParakeetEncoderAttention,
}
def _init_weights(self, module):
super()._init_weights(module)
if hasattr(self.config, "initializer_range"):
std = self.config.initializer_range
else:
# 0.02 is the standard default value across the library
std = getattr(self.config.get_text_config(), "initializer_range", 0.02)
if isinstance(module, ParakeetEncoderAttention):
# Initialize positional bias parameters
module.bias_u.data.normal_(mean=0.0, std=std)
module.bias_v.data.normal_(mean=0.0, std=std)
def _get_subsampling_output_length(self, input_lengths: torch.Tensor):
encoder_config = self.config.encoder_config if isinstance(self.config, ParakeetCTCConfig) else self.config
kernel_size = encoder_config.subsampling_conv_kernel_size
stride = encoder_config.subsampling_conv_stride
num_layers = int(math.log2(encoder_config.subsampling_factor))
all_paddings = (kernel_size - 1) // 2 * 2
add_pad = all_paddings - kernel_size
lengths = input_lengths
for _ in range(num_layers):
lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + 1.0
lengths = torch.floor(lengths)
return lengths.to(dtype=torch.int)
def _get_output_attention_mask(self, attention_mask: torch.Tensor, target_length: Optional[int] = None):
"""
Convert the input attention mask to its subsampled form. `target_length` sets the desired output length, useful
when the attention mask length differs from `sum(-1).max()` (i.e., when the longest sequence in the batch is padded)
"""
output_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
# Use target_length if provided, otherwise use max length in batch
max_length = target_length if target_length is not None else output_lengths.max()
attention_mask = torch.arange(max_length, device=attention_mask.device) < output_lengths[:, None]
return attention_mask
@auto_docstring(
custom_intro="""
The Parakeet Encoder model, based on the [Fast Conformer architecture](https://huggingface.co/papers/2305.05084).
"""
)
class ParakeetEncoder(ParakeetPreTrainedModel):
config: ParakeetEncoderConfig
base_model_prefix = "encoder"
def __init__(self, config: ParakeetEncoderConfig):
super().__init__(config)
self.config = config
self.gradient_checkpointing = False
self.dropout = config.dropout
self.dropout_positions = config.dropout_positions
self.layerdrop = config.layerdrop
self.input_scale = math.sqrt(config.hidden_size) if config.scale_input else 1.0
self.subsampling = ParakeetEncoderSubsamplingConv2D(config)
self.encode_positions = ParakeetEncoderRelPositionalEncoding(config)
self.layers = nn.ModuleList(
[ParakeetEncoderBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.post_init()
@auto_docstring
@check_model_inputs
@can_return_tuple
def forward(
self,
input_features: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutput:
r"""
Example:
```python
>>> from transformers import AutoProcessor, ParakeetEncoder
>>> from datasets import load_dataset, Audio
>>> model_id = "nvidia/parakeet-ctc-1.1b"
>>> processor = AutoProcessor.from_pretrained(model_id)
>>> encoder = ParakeetEncoder.from_pretrained(model_id)
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
>>> inputs = processor(ds[0]["audio"]["array"])
>>> encoder_outputs = encoder(**inputs)
>>> print(encoder_outputs.last_hidden_state.shape)
```
"""
hidden_states = self.subsampling(input_features, attention_mask)
hidden_states = hidden_states * self.input_scale
position_embeddings = self.encode_positions(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
position_embeddings = nn.functional.dropout(
position_embeddings, p=self.dropout_positions, training=self.training
)
if attention_mask is not None:
attention_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
attention_mask = attention_mask.unsqueeze(1).expand(-1, hidden_states.shape[1], -1)
attention_mask = attention_mask & attention_mask.transpose(1, 2)
attention_mask = attention_mask.unsqueeze(1)
for encoder_layer in self.layers:
# add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
to_drop = False
if self.training:
dropout_probability = torch.rand([])
if dropout_probability < self.layerdrop: # skip the layer
to_drop = True
if not to_drop:
hidden_states = encoder_layer(
hidden_states,
attention_mask=attention_mask,
position_embeddings=position_embeddings,
**kwargs,
)
return BaseModelOutput(last_hidden_state=hidden_states)
@dataclass
class ParakeetGenerateOutput(ModelOutput):
"""
Outputs of Parakeet models.
Args:
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The predicted token ids from CTC greedy decoding. The second dimension (`sequence_length`) corresponds to the
number of encoder frames after subsampling; frames that are not attended to are set to the pad (blank) token id.
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
"""
sequences: torch.LongTensor
logits: Optional[tuple[torch.FloatTensor]] = None
attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None
@auto_docstring(
custom_intro="""
Parakeet Encoder with a Connectionist Temporal Classification (CTC) head.
"""
)
class ParakeetForCTC(ParakeetPreTrainedModel):
config: ParakeetCTCConfig
def __init__(self, config: ParakeetCTCConfig):
super().__init__(config)
self.encoder = ParakeetEncoder(config.encoder_config)
# Conv rather than linear to be consistent with the NeMo decoding layer
self.ctc_head = nn.Conv1d(config.encoder_config.hidden_size, config.vocab_size, kernel_size=1)
self.post_init()
@auto_docstring
@can_return_tuple
def forward(
self,
input_features: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutput:
r"""
Example:
```python
>>> from transformers import AutoProcessor, ParakeetForCTC
>>> from datasets import load_dataset, Audio
>>> model_id = "nvidia/parakeet-ctc-1.1b"
>>> processor = AutoProcessor.from_pretrained(model_id)
>>> model = ParakeetForCTC.from_pretrained(model_id)
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
>>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
>>> outputs = model(**inputs)
>>> print(outputs.loss)
```"""
encoder_outputs = self.encoder(
input_features=input_features,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = encoder_outputs.last_hidden_state
logits = self.ctc_head(hidden_states.transpose(1, 2)).transpose(1, 2)
loss = None
if labels is not None:
# retrieve loss input_lengths from attention_mask
attention_mask = (
attention_mask if attention_mask is not None else torch.ones_like(input_features, dtype=torch.long)
)
input_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
# labels are padded with the pad (blank) token id; those positions are excluded from the targets
labels_mask = labels != self.config.pad_token_id
target_lengths = labels_mask.sum(-1)
flattened_targets = labels.masked_select(labels_mask)
# ctc_loss doesn't support fp16
log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
with torch.backends.cudnn.flags(enabled=False):
loss = nn.functional.ctc_loss(
log_probs,
flattened_targets,
input_lengths,
target_lengths,
blank=self.config.pad_token_id,
reduction=self.config.ctc_loss_reduction,
zero_infinity=self.config.ctc_zero_infinity,
)
return CausalLMOutput(
loss=loss,
logits=logits,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
@torch.no_grad()
def generate(
self,
input_features: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
return_dict_in_generate: bool = False,
**kwargs: Unpack[TransformersKwargs],
) -> Union[ParakeetGenerateOutput, torch.LongTensor]:
r"""
Example:
```python
>>> from transformers import AutoProcessor, ParakeetForCTC
>>> from datasets import load_dataset, Audio
>>> model_id = "nvidia/parakeet-ctc-1.1b"
>>> processor = AutoProcessor.from_pretrained(model_id)
>>> model = ParakeetForCTC.from_pretrained(model_id)
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
>>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
>>> predicted_ids = model.generate(**inputs)
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
>>> print(transcription)
```
"""
kwargs["return_dict"] = True
outputs: CausalLMOutput = self.forward(
input_features=input_features,
attention_mask=attention_mask,
**kwargs,
)
# greedy decoding
sequences = outputs.logits.argmax(dim=-1)
# mask out padded tokens
if attention_mask is not None:
attention_mask = self._get_output_attention_mask(attention_mask, target_length=sequences.shape[1])
sequences[~attention_mask] = self.config.pad_token_id
if return_dict_in_generate:
return ParakeetGenerateOutput(
sequences=sequences,
logits=outputs.logits,
attentions=outputs.attentions,
hidden_states=outputs.hidden_states,
)
return sequences
__all__ = ["ParakeetForCTC", "ParakeetEncoder", "ParakeetPreTrainedModel"]


@@ -0,0 +1,628 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Parakeet model."""
import math
from dataclasses import dataclass
from typing import Callable, Optional, Union
import torch
from torch import nn
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, CausalLMOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.generic import check_model_inputs
from ..fastspeech2_conformer.modeling_fastspeech2_conformer import FastSpeech2ConformerConvolutionModule
from ..llama.modeling_llama import LlamaAttention, eager_attention_forward
from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig
class ParakeetEncoderRelPositionalEncoding(nn.Module):
"""Relative positional encoding for Parakeet."""
inv_freq: torch.Tensor # fix linting for `register_buffer`
def __init__(self, config: ParakeetEncoderConfig, device=None):
super().__init__()
self.max_position_embeddings = config.max_position_embeddings
base = 10000.0
inv_freq = 1.0 / (
base
** (
torch.arange(0, config.hidden_size, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
/ config.hidden_size
)
)
self.register_buffer("inv_freq", inv_freq, persistent=False)
@torch.no_grad()
def forward(self, hidden_states: torch.Tensor):
seq_length = hidden_states.shape[1]
if seq_length > self.max_position_embeddings:
raise ValueError(
f"Sequence Length: {seq_length} has to be less or equal than "
f"config.max_position_embeddings {self.max_position_embeddings}."
)
position_ids = torch.arange(seq_length - 1, -seq_length, -1, device=hidden_states.device)
inv_freq_expanded = (
self.inv_freq[None, :, None].float().expand(hidden_states.shape[0], -1, 1).to(hidden_states.device)
)
position_ids_expanded = position_ids[None, None, :].float()
device_type = (
hidden_states.device.type
if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
else "cpu"
)
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
sin = freqs.sin()
cos = freqs.cos()
# interleave sin and cos
pos_embed = torch.stack([sin, cos], dim=-1)
pos_embed = pos_embed.reshape(*pos_embed.shape[:-2], -1)
return pos_embed.to(dtype=hidden_states.dtype)
class ParakeetEncoderFeedForward(nn.Module):
def __init__(self, config: ParakeetEncoderConfig):
super().__init__()
self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.attention_bias)
self.activation = ACT2FN[config.hidden_act]
self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.attention_bias)
self.activation_dropout = config.activation_dropout
def forward(self, hidden_states):
hidden_states = self.activation(self.linear1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.linear2(hidden_states)
return hidden_states
class ParakeetEncoderConvolutionModule(FastSpeech2ConformerConvolutionModule):
def __init__(self, config: ParakeetEncoderConfig, module_config=None):
super().__init__(config, module_config)
class ParakeetEncoderAttention(LlamaAttention):
"""Multi-head attention with relative positional encoding. See section 3.3 of https://huggingface.co/papers/1901.02860."""
def __init__(self, config: ParakeetEncoderConfig, layer_idx: int):
super().__init__(config, layer_idx=layer_idx)
self.is_causal = False
# W_{k,R} projection
self.relative_k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
# global content bias
self.bias_u = nn.Parameter(torch.zeros(config.num_attention_heads, self.head_dim))
# global positional bias
self.bias_v = nn.Parameter(torch.zeros(config.num_attention_heads, self.head_dim))
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, torch.Tensor]:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
hidden_shape = (batch_size, seq_length, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
query_states_with_bias_u = query_states + self.bias_u.view(
1, self.config.num_attention_heads, 1, self.head_dim
)
query_states_with_bias_v = query_states + self.bias_v.view(
1, self.config.num_attention_heads, 1, self.head_dim
)
relative_key_states = self.relative_k_proj(position_embeddings)
relative_key_states = relative_key_states.view(batch_size, -1, self.config.num_attention_heads, self.head_dim)
# terms (b) and (d)
matrix_bd = query_states_with_bias_v @ relative_key_states.permute(0, 2, 3, 1)
matrix_bd = self._rel_shift(matrix_bd)
matrix_bd = matrix_bd[..., :seq_length]
matrix_bd = matrix_bd * self.scaling
if attention_mask is not None:
# the original codebase uses -10000.0 rather than float("-inf"), and then manually fills the masked positions with 0.0
# see: https://github.com/NVIDIA-NeMo/NeMo/blob/8cfedd7203462cb251a914e700e5605444277561/nemo/collections/asr/parts/submodules/multi_head_attention.py#L320-L340
# we instead take the straightforward approach of using float("-inf")
matrix_bd = matrix_bd.masked_fill_(attention_mask.logical_not(), float("-inf"))
# will compute matrix_ac - terms (a) and (c) - and add matrix_bd
attn_output, attn_weights = attention_interface(
self,
query=query_states_with_bias_u,
key=key_states,
value=value_states,
attention_mask=matrix_bd,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
def _rel_shift(self, attention_scores):
"""Relative position shift for Shaw et al. style attention. See appendix B of https://huggingface.co/papers/1901.02860."""
batch_size, num_heads, query_length, position_length = attention_scores.shape
attention_scores = nn.functional.pad(attention_scores, pad=(1, 0))
attention_scores = attention_scores.view(batch_size, num_heads, -1, query_length)
attention_scores = attention_scores[:, :, 1:].view(batch_size, num_heads, query_length, position_length)
return attention_scores
class ParakeetEncoderSubsamplingConv2D(nn.Module):
def __init__(self, config: ParakeetEncoderConfig):
super().__init__()
self.kernel_size = config.subsampling_conv_kernel_size
self.stride = config.subsampling_conv_stride
self.channels = config.subsampling_conv_channels
self.padding = (self.kernel_size - 1) // 2
self.num_layers = int(math.log2(config.subsampling_factor))
# define layers
self.layers = nn.ModuleList()
self.layers.append(
nn.Conv2d(1, self.channels, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding)
)
self.layers.append(nn.ReLU())
for i in range(self.num_layers - 1):
# depthwise conv
self.layers.append(
nn.Conv2d(
self.channels,
self.channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
groups=self.channels,
)
)
# pointwise conv
self.layers.append(nn.Conv2d(self.channels, self.channels, kernel_size=1))
# activation
self.layers.append(nn.ReLU())
out_length = config.num_mel_bins // (self.stride**self.num_layers)
self.linear = nn.Linear(config.subsampling_conv_channels * out_length, config.hidden_size, bias=True)
def _get_output_length(self, input_lengths: torch.Tensor, conv_layer: nn.Conv2d):
if hasattr(conv_layer, "stride") and conv_layer.stride != (1, 1):
padding = conv_layer.padding
kernel_size = conv_layer.kernel_size[0]
stride = conv_layer.stride[0]
output_lengths = (input_lengths + padding[0] + padding[1] - kernel_size) // stride + 1
return output_lengths
return input_lengths
def forward(self, input_features: torch.Tensor, attention_mask: torch.Tensor = None):
hidden_states = input_features.unsqueeze(1)
current_lengths = attention_mask.sum(-1) if attention_mask is not None else None
for layer in self.layers:
hidden_states = layer(hidden_states)
# mask the hidden states
if isinstance(layer, nn.Conv2d) and attention_mask is not None:
current_lengths = self._get_output_length(current_lengths, layer)
current_seq_length = hidden_states.shape[2]
channel_mask = (
torch.arange(current_seq_length, device=attention_mask.device) < current_lengths[:, None]
)
hidden_states *= channel_mask[:, None, :, None]
hidden_states = hidden_states.transpose(1, 2).reshape(hidden_states.shape[0], hidden_states.shape[2], -1)
hidden_states = self.linear(hidden_states)
return hidden_states
class ParakeetEncoderBlock(GradientCheckpointingLayer):
def __init__(self, config: ParakeetEncoderConfig, layer_idx: Optional[int] = None):
super().__init__()
self.gradient_checkpointing = False
self.feed_forward1 = ParakeetEncoderFeedForward(config)
self.self_attn = ParakeetEncoderAttention(config, layer_idx)
self.conv = ParakeetEncoderConvolutionModule(config)
self.feed_forward2 = ParakeetEncoderFeedForward(config)
self.norm_feed_forward1 = nn.LayerNorm(config.hidden_size)
self.norm_self_att = nn.LayerNorm(config.hidden_size)
self.norm_conv = nn.LayerNorm(config.hidden_size)
self.norm_feed_forward2 = nn.LayerNorm(config.hidden_size)
self.norm_out = nn.LayerNorm(config.hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.feed_forward1(self.norm_feed_forward1(hidden_states))
hidden_states = residual + 0.5 * hidden_states # the conformer architecture uses a factor of 0.5
normalized_hidden_states = self.norm_self_att(hidden_states)
attn_output, _ = self.self_attn(
hidden_states=normalized_hidden_states,
attention_mask=attention_mask,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = hidden_states + attn_output
conv_output = self.conv(self.norm_conv(hidden_states), attention_mask=attention_mask)
hidden_states = hidden_states + conv_output
ff2_output = self.feed_forward2(self.norm_feed_forward2(hidden_states))
hidden_states = hidden_states + 0.5 * ff2_output # the conformer architecture uses a factor of 0.5
hidden_states = self.norm_out(hidden_states)
return hidden_states
@auto_docstring
class ParakeetPreTrainedModel(PreTrainedModel):
config: ParakeetCTCConfig
base_model_prefix = "model"
main_input_name = "input_features"
supports_gradient_checkpointing = True
_no_split_modules = ["ParakeetEncoderBlock"]
_supports_flat_attention_mask = True
_supports_sdpa = True
_supports_flex_attn = True
# TODO: @eustlb, add support when flash attention supports custom attention bias
_supports_flash_attn = False
_can_compile_fullgraph = True
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": ParakeetEncoderBlock,
"attentions": ParakeetEncoderAttention,
}
def _init_weights(self, module):
super()._init_weights(module)
if hasattr(self.config, "initializer_range"):
std = self.config.initializer_range
else:
# 0.02 is the standard default value across the library
std = getattr(self.config.get_text_config(), "initializer_range", 0.02)
if isinstance(module, ParakeetEncoderAttention):
# Initialize positional bias parameters
module.bias_u.data.normal_(mean=0.0, std=std)
module.bias_v.data.normal_(mean=0.0, std=std)
def _get_subsampling_output_length(self, input_lengths: torch.Tensor):
encoder_config = self.config.encoder_config if isinstance(self.config, ParakeetCTCConfig) else self.config
kernel_size = encoder_config.subsampling_conv_kernel_size
stride = encoder_config.subsampling_conv_stride
num_layers = int(math.log2(encoder_config.subsampling_factor))
all_paddings = (kernel_size - 1) // 2 * 2
add_pad = all_paddings - kernel_size
lengths = input_lengths
for _ in range(num_layers):
lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + 1.0
lengths = torch.floor(lengths)
return lengths.to(dtype=torch.int)
def _get_output_attention_mask(self, attention_mask: torch.Tensor, target_length: Optional[int] = None):
"""
Convert the input attention mask to its subsampled form. `target_length` sets the desired output length, useful
when the attention mask length differs from `sum(-1).max()` (i.e., when the longest sequence in the batch is padded)
"""
output_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
# Use target_length if provided, otherwise use max length in batch
max_length = target_length if target_length is not None else output_lengths.max()
attention_mask = torch.arange(max_length, device=attention_mask.device) < output_lengths[:, None]
return attention_mask
@auto_docstring(
custom_intro="""
The Parakeet Encoder model, based on the [Fast Conformer architecture](https://huggingface.co/papers/2305.05084).
"""
)
class ParakeetEncoder(ParakeetPreTrainedModel):
config: ParakeetEncoderConfig
base_model_prefix = "encoder"
def __init__(self, config: ParakeetEncoderConfig):
super().__init__(config)
self.config = config
self.gradient_checkpointing = False
self.dropout = config.dropout
self.dropout_positions = config.dropout_positions
self.layerdrop = config.layerdrop
self.input_scale = math.sqrt(config.hidden_size) if config.scale_input else 1.0
self.subsampling = ParakeetEncoderSubsamplingConv2D(config)
self.encode_positions = ParakeetEncoderRelPositionalEncoding(config)
self.layers = nn.ModuleList(
[ParakeetEncoderBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.post_init()
@auto_docstring
@check_model_inputs
@can_return_tuple
def forward(
self,
input_features: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutput:
r"""
Example:
```python
>>> from transformers import AutoProcessor, ParakeetEncoder
>>> from datasets import load_dataset, Audio
>>> model_id = "nvidia/parakeet-ctc-1.1b"
>>> processor = AutoProcessor.from_pretrained(model_id)
>>> encoder = ParakeetEncoder.from_pretrained(model_id)
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
>>> inputs = processor(ds[0]["audio"]["array"])
>>> encoder_outputs = encoder(**inputs)
>>> print(encoder_outputs.last_hidden_state.shape)
```
"""
hidden_states = self.subsampling(input_features, attention_mask)
hidden_states = hidden_states * self.input_scale
position_embeddings = self.encode_positions(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
position_embeddings = nn.functional.dropout(
position_embeddings, p=self.dropout_positions, training=self.training
)
if attention_mask is not None:
attention_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
attention_mask = attention_mask.unsqueeze(1).expand(-1, hidden_states.shape[1], -1)
attention_mask = attention_mask & attention_mask.transpose(1, 2)
attention_mask = attention_mask.unsqueeze(1)
for encoder_layer in self.layers:
# add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
to_drop = False
if self.training:
dropout_probability = torch.rand([])
if dropout_probability < self.layerdrop: # skip the layer
to_drop = True
if not to_drop:
hidden_states = encoder_layer(
hidden_states,
attention_mask=attention_mask,
position_embeddings=position_embeddings,
**kwargs,
)
return BaseModelOutput(last_hidden_state=hidden_states)
@dataclass
class ParakeetGenerateOutput(ModelOutput):
"""
Outputs of Parakeet models.
Args:
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The predicted token ids from CTC greedy decoding. The second dimension (`sequence_length`) corresponds to the
number of encoder frames after subsampling; frames that are not attended to are set to the pad (blank) token id.
logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
"""
sequences: torch.LongTensor
logits: Optional[tuple[torch.FloatTensor]] = None
attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None
@auto_docstring(
custom_intro="""
Parakeet Encoder with a Connectionist Temporal Classification (CTC) head.
"""
)
class ParakeetForCTC(ParakeetPreTrainedModel):
config: ParakeetCTCConfig
def __init__(self, config: ParakeetCTCConfig):
super().__init__(config)
self.encoder = ParakeetEncoder(config.encoder_config)
# Conv rather than linear to be consistent with the NeMo decoding layer
self.ctc_head = nn.Conv1d(config.encoder_config.hidden_size, config.vocab_size, kernel_size=1)
self.post_init()
@auto_docstring
@can_return_tuple
def forward(
self,
input_features: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutput:
r"""
Example:
```python
>>> from transformers import AutoProcessor, ParakeetForCTC
>>> from datasets import load_dataset, Audio
>>> model_id = "nvidia/parakeet-ctc-1.1b"
>>> processor = AutoProcessor.from_pretrained(model_id)
>>> model = ParakeetForCTC.from_pretrained(model_id)
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
>>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
>>> outputs = model(**inputs)
>>> print(outputs.loss)
```"""
encoder_outputs = self.encoder(
input_features=input_features,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = encoder_outputs.last_hidden_state
logits = self.ctc_head(hidden_states.transpose(1, 2)).transpose(1, 2)
loss = None
if labels is not None:
# retrieve loss input_lengths from attention_mask
attention_mask = (
attention_mask if attention_mask is not None else torch.ones_like(input_features, dtype=torch.long)
)
input_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
# labels are padded with the pad (blank) token id; those positions are excluded from the targets
labels_mask = labels != self.config.pad_token_id
target_lengths = labels_mask.sum(-1)
flattened_targets = labels.masked_select(labels_mask)
# ctc_loss doesn't support fp16
log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
with torch.backends.cudnn.flags(enabled=False):
loss = nn.functional.ctc_loss(
log_probs,
flattened_targets,
input_lengths,
target_lengths,
blank=self.config.pad_token_id,
reduction=self.config.ctc_loss_reduction,
zero_infinity=self.config.ctc_zero_infinity,
)
return CausalLMOutput(
loss=loss,
logits=logits,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
@torch.no_grad()
def generate(
self,
input_features: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
return_dict_in_generate: bool = False,
**kwargs: Unpack[TransformersKwargs],
) -> Union[ParakeetGenerateOutput, torch.LongTensor]:
r"""
Example:
```python
>>> from transformers import AutoProcessor, ParakeetForCTC
>>> from datasets import load_dataset, Audio
>>> model_id = "nvidia/parakeet-ctc-1.1b"
>>> processor = AutoProcessor.from_pretrained(model_id)
>>> model = ParakeetForCTC.from_pretrained(model_id)
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
>>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
>>> predicted_ids = model.generate(**inputs)
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
>>> print(transcription)
```
"""
kwargs["return_dict"] = True
outputs: CausalLMOutput = self.forward(
input_features=input_features,
attention_mask=attention_mask,
**kwargs,
)
# greedy decoding
sequences = outputs.logits.argmax(dim=-1)
# mask out padded tokens
if attention_mask is not None:
attention_mask = self._get_output_attention_mask(attention_mask, target_length=sequences.shape[1])
sequences[~attention_mask] = self.config.pad_token_id
if return_dict_in_generate:
return ParakeetGenerateOutput(
sequences=sequences,
logits=outputs.logits,
attentions=outputs.attentions,
hidden_states=outputs.hidden_states,
)
return sequences
__all__ = ["ParakeetForCTC", "ParakeetEncoder", "ParakeetPreTrainedModel"]


@@ -0,0 +1,87 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
from ...audio_utils import AudioInput, make_list_of_audio
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging
logger = logging.get_logger(__name__)
class ParakeetProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"audio_kwargs": {
"sampling_rate": 16000,
"padding": "longest",
},
"text_kwargs": {
"padding": True,
"padding_side": "right",
"add_special_tokens": False,
},
"common_kwargs": {"return_tensors": "pt"},
}
class ParakeetProcessor(ProcessorMixin):
attributes = ["feature_extractor", "tokenizer"]
feature_extractor_class = "ParakeetFeatureExtractor"
tokenizer_class = "ParakeetTokenizerFast"
def __call__(
self,
audio: AudioInput,
text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput], None] = None,
sampling_rate: Optional[int] = None,
**kwargs: Unpack[ParakeetProcessorKwargs],
):
audio = make_list_of_audio(audio)
output_kwargs = self._merge_kwargs(
ParakeetProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if sampling_rate is None:
logger.warning_once(
f"You've provided audio without specifying the sampling rate. It will be assumed to be {output_kwargs['audio_kwargs']['sampling_rate']}, which can result in silent errors."
)
elif sampling_rate != output_kwargs["audio_kwargs"]["sampling_rate"]:
raise ValueError(
f"The sampling rate of the audio ({sampling_rate}) does not match the sampling rate of the processor ({output_kwargs['audio_kwargs']['sampling_rate']}). Please provide resampled the audio to the expected sampling rate."
)
if audio is not None:
    inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
if text is not None:
    encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])
    inputs["labels"] = encodings["input_ids"]
return inputs
@property
def model_input_names(self):
feature_extractor_input_names = self.feature_extractor.model_input_names
return feature_extractor_input_names + ["labels"]
__all__ = ["ParakeetProcessor"]


@@ -0,0 +1,54 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Optional, Union
from ...tokenization_utils_fast import PreTrainedTokenizerFast
class ParakeetTokenizerFast(PreTrainedTokenizerFast):
"""
Inherits all methods from [`PreTrainedTokenizerFast`]. Users should refer to this superclass for more information regarding those methods,
except for `_decode` which is overridden to adapt it to CTC decoding:
1. Group consecutive tokens
2. Filter out the blank token
"""
def _decode(
self,
token_ids: Union[int, list[int]],
skip_special_tokens: bool = False,
clean_up_tokenization_spaces: Optional[bool] = None,
group_tokens: bool = True,
**kwargs,
) -> str:
if isinstance(token_ids, int):
token_ids = [token_ids]
if group_tokens:
token_ids = [token_group[0] for token_group in itertools.groupby(token_ids)]
# for CTC we filter out the blank token, which is the pad token
token_ids = [token for token in token_ids if token != self.pad_token_id]
return super()._decode(
token_ids=token_ids,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)
__all__ = ["ParakeetTokenizerFast"]

File diff suppressed because one or more lines are too long


@@ -0,0 +1 @@
{"transcriptions": ["mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"], "scores": [-0.08922013640403748], "token_ids": [[1024, 1024, 1024, 1024, 1024, 1024, 19, 37, 132, 1024, 1024, 264, 128, 1024, 1024, 1024, 132, 1024, 58, 1024, 5, 645, 1024, 1000, 82, 52, 1024, 34, 1024, 5, 19, 68, 1007, 52, 1024, 235, 1024, 388, 1024, 27, 1024, 25, 1024, 56, 1024, 103, 1024, 1024, 727, 112, 1024, 22, 1024, 56, 1006, 1009, 405, 1024, 1024, 217, 1024, 1024, 95, 1003, 1024, 133, 1006, 1024, 1024, 1024, 1024, 1024, 1024, 1024]]}



@@ -0,0 +1,197 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the Parakeet feature extraction."""
import itertools
import random
import unittest
import numpy as np
from transformers import ParakeetFeatureExtractor
from transformers.testing_utils import require_torch
from transformers.utils import is_datasets_available, is_torch_available
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
if is_torch_available():
import torch
if is_datasets_available():
from datasets import load_dataset
global_rng = random.Random()
def floats_list(shape, scale=1.0, rng=None, name=None):
"""Creates a random float32 tensor"""
if rng is None:
rng = global_rng
values = []
for batch_idx in range(shape[0]):
values.append([])
for _ in range(shape[1]):
values[-1].append(rng.random() * scale)
return values
class ParakeetFeatureExtractionTester:
def __init__(
self,
parent,
batch_size=7,
min_seq_length=400,
max_seq_length=2000,
feature_size=80,
hop_length=160,
win_length=400,
n_fft=512,
sampling_rate=16000,
padding_value=0.0,
):
self.parent = parent
self.batch_size = batch_size
self.min_seq_length = min_seq_length
self.max_seq_length = max_seq_length
self.seq_length_diff = (self.max_seq_length - self.min_seq_length) // (self.batch_size - 1)
self.feature_size = feature_size
self.hop_length = hop_length
self.win_length = win_length
self.n_fft = n_fft
self.sampling_rate = sampling_rate
self.padding_value = padding_value
def prepare_feat_extract_dict(self):
return {
"feature_size": self.feature_size,
"hop_length": self.hop_length,
"win_length": self.win_length,
"n_fft": self.n_fft,
"sampling_rate": self.sampling_rate,
"padding_value": self.padding_value,
}
# Copied from tests.models.whisper.test_feature_extraction_whisper.WhisperFeatureExtractionTester.prepare_inputs_for_common
def prepare_inputs_for_common(self, equal_length=False, numpify=False):
def _flatten(list_of_lists):
return list(itertools.chain(*list_of_lists))
if equal_length:
speech_inputs = [floats_list((self.max_seq_length, self.feature_size)) for _ in range(self.batch_size)]
else:
# make sure that inputs increase in size
speech_inputs = [
floats_list((x, self.feature_size))
for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
]
if numpify:
speech_inputs = [np.asarray(x) for x in speech_inputs]
return speech_inputs
class ParakeetFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
feature_extraction_class = ParakeetFeatureExtractor
def setUp(self):
self.feat_extract_tester = ParakeetFeatureExtractionTester(self)
def _load_datasamples(self, num_samples):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# automatic decoding with librispeech
speech_samples = ds.sort("id")[:num_samples]["audio"]
return [x["array"] for x in speech_samples]
@require_torch
def test_torch_integration(self):
"""
reproducer: https://gist.github.com/eustlb/c4a0999e54466b7e8d8b040d8e0900df
"""
# fmt: off
EXPECTED_INPUT_FEATURES = torch.tensor(
[
0.60935932, 1.18187428, 1.29877627, 1.36461377, 1.09311509, 1.39821815,
1.63753450, 1.37100816, 1.26510608, 1.70332706, 1.69067430, 1.28770995,
1.52999651, 1.77962756, 1.71420062, 1.21944094, 1.30884087, 1.44343364,
1.17694926, 1.42690814, 1.78877723, 1.68655288, 1.27155364, 1.66103351,
1.75820673, 1.41575801, 1.40622294, 1.70603478, 1.63117850, 1.13353217,
]
)
# fmt: on
input_speech = self._load_datasamples(1)
feature_extractor = ParakeetFeatureExtractor()
inputs = feature_extractor(input_speech, return_tensors="pt")
self.assertEqual(inputs.input_features.shape, (1, 586, 80))
torch.testing.assert_close(inputs.input_features[0, 100, :30], EXPECTED_INPUT_FEATURES, atol=1e-4, rtol=1e-4)
self.assertEqual(inputs.attention_mask.shape, (1, 586))
# last frame should be masked
self.assertEqual(inputs.attention_mask.sum(), 585)
@require_torch
def test_torch_integration_batch(self):
"""
reproducer: https://gist.github.com/eustlb/c4a0999e54466b7e8d8b040d8e0900df
"""
# fmt: off
EXPECTED_INPUT_FEATURES = torch.tensor(
[
[ 0.60935932, 1.18187428, 1.29877627, 1.36461377, 1.09311533,
1.39821827, 1.63753450, 1.37100816, 1.26510608, 1.70332706,
1.69067478, 1.28770995, 1.52999651, 1.77962780, 1.71420062,
1.21944094, 1.30884087, 1.44343400, 1.17694926, 1.42690814,
1.78877664, 1.68655288, 1.27155364, 1.66103351, 1.75820673,
1.41575801, 1.40622294, 1.70603478, 1.63117862, 1.13353217],
[ 0.58339858, 0.54317272, 0.46222782, 0.34154415, 0.17806509,
0.32182255, 0.28909618, 0.02141305, -0.09710173, -0.35818669,
-0.48172510, -0.52942866, -0.58029658, -0.70519227, -0.67929971,
-0.54698551, -0.28611183, -0.24780270, -0.31363955, -0.41913241,
-0.32394424, -0.44897896, -0.68657434, -0.62047797, -0.46886450,
-0.65987164, -1.02435589, -0.58527517, -0.56095684, -0.73582536],
[-0.91937613, -0.97933632, -1.06843162, -1.02642107, -0.94232899,
-0.83840621, -0.82306921, -0.45763230, -0.45182887, -0.75917768,
-0.42541453, -0.28512970, -0.39637473, -0.66478080, -0.68004298,
-0.49690303, -0.31799242, -0.12917191, 0.13149273, 0.10163058,
-0.40041649, 0.05001565, 0.23906317, 0.28816083, 0.14308788,
-0.29588422, -0.05428466, 0.14418560, 0.28865972, -0.12138986],
[ 0.73217624, 0.84484011, 0.79323846, 0.66315967, 0.41556871,
0.88633078, 0.90718138, 0.91268104, 1.15920067, 1.26141894,
1.10222173, 0.92990804, 0.96352047, 0.88142169, 0.56635213,
0.71491158, 0.81301254, 0.67301887, 0.74780160, 0.64429688,
0.22885245, 0.47035533, 0.46498337, 0.17544533, 0.44458991,
0.79245001, 0.57207537, 0.85768145, 1.00491571, 0.93360955],
[ 1.40496337, 1.32492661, 1.16519547, 0.98379827, 0.77614164,
0.95871657, 0.81910741, 1.23010278, 1.33011520, 1.16538525,
1.28319681, 1.45041633, 1.33421600, 0.91677380, 0.67107433,
0.52890682, 0.82009870, 1.15821445, 1.15343642, 1.10958862,
1.44962490, 1.44485891, 1.46043479, 1.90800595, 1.95863307,
1.63670933, 1.49021459, 1.18701911, 0.74906683, 0.84700620]
]
)
# fmt: on
input_speech = self._load_datasamples(5)
feature_extractor = ParakeetFeatureExtractor()
inputs = feature_extractor(input_speech, return_tensors="pt")
self.assertEqual(inputs.input_features.shape, (5, 2941, 80))
torch.testing.assert_close(inputs.input_features[:, 100, :30], EXPECTED_INPUT_FEATURES, atol=1e-4, rtol=1e-4)
self.assertEqual(inputs.attention_mask.shape, (5, 2941))
self.assertEqual(inputs.attention_mask.sum(dim=-1).tolist(), [585, 481, 1248, 990, 2940])
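As a sanity check on the shapes asserted above (assuming the tester's defaults of hop_length=160 at 16 kHz, i.e. 10 ms per frame):

```python
# 586 frames * 160-sample hop at 16 kHz ~= 5.86 s of audio for the first clip;
# the last frame is padding, hence attention_mask.sum() == 585 in the single-clip test.
hop_length, sampling_rate = 160, 16_000
print(586 * hop_length / sampling_rate)  # 5.86
```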


@@ -0,0 +1,380 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Parakeet model."""
import json
import tempfile
import unittest
from pathlib import Path
from transformers import is_datasets_available, is_torch_available
from transformers.testing_utils import cleanup, require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_datasets_available():
from datasets import Audio, load_dataset
if is_torch_available():
import torch
from transformers import (
AutoProcessor,
ParakeetCTCConfig,
ParakeetEncoder,
ParakeetEncoderConfig,
ParakeetForCTC,
)
class ParakeetEncoderModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=1024,
is_training=True,
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=256,
hidden_act="silu",
dropout=0, # so gradient checkpointing doesn't fail
conv_kernel_size=9,
subsampling_factor=8,
subsampling_conv_channels=32,
use_bias=True,
num_mel_bins=80,
scale_input=True,
):
# testing suite parameters
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.num_mel_bins = num_mel_bins
self.is_training = is_training
# config parameters
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.dropout = dropout
self.conv_kernel_size = conv_kernel_size
self.subsampling_factor = subsampling_factor
self.subsampling_conv_channels = subsampling_conv_channels
self.use_bias = use_bias
self.num_mel_bins = num_mel_bins
self.scale_input = scale_input
# Calculate output sequence length after subsampling
self.output_seq_length = seq_length // subsampling_factor
self.encoder_seq_length = self.output_seq_length
self.key_length = self.output_seq_length
def prepare_config_and_inputs(self):
input_features = floats_tensor([self.batch_size, self.seq_length, self.num_mel_bins])
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
config = self.get_config()
return config, input_features, attention_mask
def get_config(self):
return ParakeetEncoderConfig(
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
dropout=self.dropout,
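# every dropout-style argument reuses the same value (0 by default in the tester __init__)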
dropout_positions=self.dropout,
layerdrop=self.dropout,
activation_dropout=self.dropout,
attention_dropout=self.dropout,
conv_kernel_size=self.conv_kernel_size,
subsampling_factor=self.subsampling_factor,
subsampling_conv_channels=self.subsampling_conv_channels,
use_bias=self.use_bias,
num_mel_bins=self.num_mel_bins,
scale_input=self.scale_input,
)
def create_and_check_model(self, config, input_features, attention_mask):
model = ParakeetEncoder(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model(input_features, attention_mask=attention_mask)
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, config.hidden_size)
)
def prepare_config_and_inputs_for_common(self):
config, input_features, attention_mask = self.prepare_config_and_inputs()
inputs_dict = {
"input_features": input_features,
"attention_mask": attention_mask,
}
return config, inputs_dict
def check_ctc_loss(self, config, input_values, *args):
model = ParakeetForCTC(config=config)
model.to(torch_device)
# make sure that dropout is disabled
model.eval()
input_values = input_values[:3]
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)
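# three different effective lengths (quarter, half, full) so padding varies across the batch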
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
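# CTC requires each label sequence to be no longer than its encoder output, so cap labels below the shortest output length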
labels = ids_tensor((input_values.shape[0], min(max_length_labels) - 1), model.config.vocab_size)
# pad input
for i in range(len(input_lengths)):
input_values[i, input_lengths[i] :] = 0.0
attention_mask[i, input_lengths[i] :] = 0
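# compute the loss with both reductions on the same padded batch and check a plain float comes back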
model.config.ctc_loss_reduction = "sum"
sum_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
model.config.ctc_loss_reduction = "mean"
mean_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
self.parent.assertTrue(isinstance(sum_loss, float))
self.parent.assertTrue(isinstance(mean_loss, float))
@require_torch
class ParakeetEncoderModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (ParakeetEncoder,) if is_torch_available() else ()
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
def setUp(self):
self.model_tester = ParakeetEncoderModelTester(self)
self.config_tester = ConfigTester(self, config_class=ParakeetEncoderConfig, has_text_modality=False)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="ParakeetEncoder does not use inputs_embeds")
def test_model_get_set_embeddings(self):
pass
class ParakeetForCTCModelTester:
def __init__(self, parent, encoder_kwargs=None, is_training=True, vocab_size=128, pad_token_id=0):
if encoder_kwargs is None:
encoder_kwargs = {}
self.parent = parent
self.encoder_model_tester = ParakeetEncoderModelTester(parent, **encoder_kwargs)
self.is_training = is_training
self.batch_size = self.encoder_model_tester.batch_size
self.output_seq_length = self.encoder_model_tester.output_seq_length
self.num_hidden_layers = self.encoder_model_tester.num_hidden_layers
self.seq_length = vocab_size
self.hidden_size = self.encoder_model_tester.hidden_size
self.vocab_size = vocab_size
self.pad_token_id = pad_token_id
def prepare_config_and_inputs(self):
_, input_features, attention_mask = self.encoder_model_tester.prepare_config_and_inputs()
config = self.get_config()
return config, input_features, attention_mask
def get_config(self):
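# compose the CTC config from the encoder tester's config plus the CTC-specific settings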
return ParakeetCTCConfig.from_encoder_config(
encoder_config=self.encoder_model_tester.get_config(),
vocab_size=self.vocab_size,
pad_token_id=self.pad_token_id,
)
def create_and_check_model(self, config, input_features, attention_mask):
model = ParakeetForCTC(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model(input_features, attention_mask=attention_mask)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.output_seq_length, self.vocab_size))
def prepare_config_and_inputs_for_common(self):
config, input_features, attention_mask = self.prepare_config_and_inputs()
inputs_dict = {
"input_features": input_features,
"attention_mask": attention_mask,
}
return config, inputs_dict
def test_ctc_loss_inference(self):
config_and_inputs = self.prepare_config_and_inputs()
self.encoder_model_tester.check_ctc_loss(*config_and_inputs)
@require_torch
class ParakeetForCTCModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (ParakeetForCTC,) if is_torch_available() else ()
pipeline_model_mapping = (
{
"feature-extraction": ParakeetEncoder,
"automatic-speech-recognition": ParakeetForCTC,
}
if is_torch_available()
else {}
)
test_attention_outputs = False
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_torch_exportable = True
_is_composite = True
def setUp(self):
self.model_tester = ParakeetForCTCModelTester(self)
self.config_tester = ConfigTester(self, config_class=ParakeetCTCConfig)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="ParakeetEncoder does not use inputs_embeds")
def test_model_get_set_embeddings(self):
pass
# The original test assumes a vision+text model, so override it since Parakeet is audio+text
# Adapted from `tests/models/granite_speech/test_modeling_granite_speech.py`
def test_sdpa_can_dispatch_composite_models(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
if not self._is_composite:
self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA")
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_sdpa = model_class.from_pretrained(tmpdirname)
model_sdpa = model_sdpa.eval().to(torch_device)
model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager")
model_eager = model_eager.eval().to(torch_device)
self.assertTrue(model_eager.config._attn_implementation == "eager")
for name, submodule in model_eager.named_modules():
class_name = submodule.__class__.__name__
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers")
@require_torch
class ParakeetForCTCIntegrationTest(unittest.TestCase):
_dataset = None
@classmethod
def setUpClass(cls):
cls.checkpoint_name = "nvidia/parakeet-ctc-1.1b"
cls.dtype = torch.bfloat16
cls.processor = AutoProcessor.from_pretrained("nvidia/parakeet-ctc-1.1b")
def tearDown(self):
cleanup(torch_device, gc_collect=True)
@classmethod
def _load_dataset(cls):
# Lazy loading of the dataset. Because it is a class method, it will only be loaded once per pytest process.
if cls._dataset is None:
cls._dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
cls._dataset = cls._dataset.cast_column(
"audio", Audio(sampling_rate=cls.processor.feature_extractor.sampling_rate)
)
def _load_datasamples(self, num_samples):
self._load_dataset()
ds = self._dataset
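# sort by id for a deterministic sample order, then keep the raw waveform arrays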
speech_samples = ds.sort("id")[:num_samples]["audio"]
return [x["array"] for x in speech_samples]
@slow
def test_1b_model_integration(self):
"""
bezzam reproducer (creates JSON directly in repo): https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_single-py
eustlb reproducer: https://gist.github.com/eustlb/6e9e3aa85de3f7c340ec3c36e65f2fe6
"""
RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_single.json"
with open(RESULTS_PATH, "r") as f:
raw_data = json.load(f)
EXPECTED_TOKEN_IDS = torch.tensor(raw_data["token_ids"])
EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"]
samples = self._load_datasamples(1)
model = ParakeetForCTC.from_pretrained(self.checkpoint_name, torch_dtype=self.dtype, device_map=torch_device)
model.eval()
model.to(torch_device)
# -- apply
inputs = self.processor(samples)
inputs.to(torch_device, dtype=self.dtype)
predicted_ids = model.generate(**inputs)
torch.testing.assert_close(predicted_ids.cpu(), EXPECTED_TOKEN_IDS)
predicted_transcripts = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS)
@slow
def test_1b_model_integration_batched(self):
"""
bezzam reproducer (creates JSON directly in repo): https://gist.github.com/ebezzam/6382bdabfc64bb2541ca9f77deb7678d#file-reproducer_batched-py
eustlb reproducer: https://gist.github.com/eustlb/575b5da58de34a70116a1955b1183596
"""
RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/parakeet/expected_results_batch.json"
with open(RESULTS_PATH, "r") as f:
raw_data = json.load(f)
EXPECTED_TOKEN_IDS = torch.tensor(raw_data["token_ids"])
EXPECTED_TRANSCRIPTIONS = raw_data["transcriptions"]
samples = self._load_datasamples(5)
model = ParakeetForCTC.from_pretrained(self.checkpoint_name, torch_dtype=self.dtype, device_map=torch_device)
model.eval()
model.to(torch_device)
# -- apply
inputs = self.processor(samples)
inputs.to(torch_device, dtype=self.dtype)
predicted_ids = model.generate(**inputs)
torch.testing.assert_close(predicted_ids.cpu(), EXPECTED_TOKEN_IDS)
predicted_transcripts = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
self.assertListEqual(predicted_transcripts, EXPECTED_TRANSCRIPTIONS)


@@ -0,0 +1,49 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest
from transformers import AutoProcessor, ParakeetProcessor
from transformers.testing_utils import require_torch, require_torchaudio
from ...test_processing_common import ProcessorTesterMixin
@require_torch
@require_torchaudio
class ParakeetProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = ParakeetProcessor
text_input_name = "labels"
@classmethod
def setUpClass(cls):
cls.tmpdirname = tempfile.mkdtemp()
cls.checkpoint = "nvidia/parakeet-ctc-1.1b"
processor = ParakeetProcessor.from_pretrained(cls.checkpoint)
processor.save_pretrained(cls.tmpdirname)
def get_tokenizer(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
def get_feature_extractor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).feature_extractor
def get_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs)
@classmethod
def tearDownClass(cls):
shutil.rmtree(cls.tmpdirname, ignore_errors=True)


@@ -0,0 +1,53 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the ParakeetCTC tokenizer."""
import unittest
from transformers.models.parakeet import ParakeetTokenizerFast
from ...test_tokenization_common import TokenizerTesterMixin
class ParakeetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
slow_tokenizer_class = None
rust_tokenizer_class = ParakeetTokenizerFast
tokenizer_class = ParakeetTokenizerFast
test_slow_tokenizer = False
test_rust_tokenizer = True
from_pretrained_id = "nvidia/parakeet-ctc-1.1b"
@classmethod
def setUpClass(cls):
super().setUpClass()
tokenizer = ParakeetTokenizerFast.from_pretrained("nvidia/parakeet-ctc-1.1b")
tokenizer.save_pretrained(cls.tmpdirname)
@unittest.skip(
reason="This test does not apply to ParakeetTokenizerFast. More details in the test docstring itself."
)
def test_added_tokens_do_lower_case(self):
"""
Precompiled normalization from sentencepiece is `nmt_nfkc_cf`, which includes lowercasing, yet ParakeetTokenizerFast does not have a `do_lower_case` attribute.
This results in the test failing.
"""
pass
@unittest.skip(reason="This needs a slow tokenizer. Parakeet does not have one!")
def test_encode_decode_with_spaces(self):
return
@unittest.skip(reason="ParakeetTokenizerFast doesn't have tokenizer_file in its signature.")
def test_rust_tokenizer_signature(self):
pass


@@ -55,6 +55,7 @@ from transformers.models.auto.modeling_auto import (
MODEL_FOR_BACKBONE_MAPPING_NAMES,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES,
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
MODEL_FOR_CTC_MAPPING_NAMES,
MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
@@ -657,6 +658,7 @@ class ModelTesterMixin:
*get_values(MODEL_FOR_MASKED_LM_MAPPING_NAMES),
*get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES),
*get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES),
*get_values(MODEL_FOR_CTC_MAPPING_NAMES),
]:
inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device


@@ -315,6 +315,7 @@ OBJECTS_TO_IGNORE = {
"OpenAIGPTTokenizerFast",
"OpenLlamaConfig",
"PLBartConfig",
"ParakeetCTCConfig",
"PegasusConfig",
"PegasusTokenizer",
"PegasusTokenizerFast",