Compare commits

...

1 Commit

Author SHA1 Message Date
6c9f50deec fix audio pipeline with torchcodec input 2025-07-09 15:55:27 +02:00
3 changed files with 14 additions and 18 deletions

View File

@@ -174,14 +174,7 @@ class AudioClassificationPipeline(Pipeline):
if isinstance(inputs, bytes):
inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)
if is_torch_available():
import torch
if isinstance(inputs, torch.Tensor):
inputs = inputs.cpu().numpy()
if is_torchcodec_available():
import torch
import torchcodec
if isinstance(inputs, torchcodec.decoders.AudioDecoder):
@@ -224,10 +217,14 @@ class AudioClassificationPipeline(Pipeline):
self.feature_extractor.sampling_rate,
).numpy()
if is_torch_available():
import torch
if isinstance(inputs, torch.Tensor):
inputs = inputs.cpu().numpy()
if not isinstance(inputs, np.ndarray):
raise TypeError("We expect a numpy ndarray or torch tensor as input")
if len(inputs.shape) != 1:
raise ValueError("We expect a single channel audio input for AudioClassificationPipeline")
processed = self.feature_extractor(
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"

View File

@@ -365,12 +365,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
stride = None
extra = {}
if is_torch_available():
import torch
if isinstance(inputs, torch.Tensor):
inputs = inputs.cpu().numpy()
if is_torchcodec_available():
import torchcodec
@@ -425,10 +419,15 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
# can add extra data in the inputs, so we need to keep track
# of the original length in the stride so we can cut properly.
stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))
if is_torch_available():
import torch
if isinstance(inputs, torch.Tensor):
inputs = inputs.cpu().numpy()
if not isinstance(inputs, np.ndarray):
raise TypeError(f"We expect a numpy ndarray or torch tensor as input, got `{type(inputs)}`")
if len(inputs.shape) != 1:
raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
if chunk_length_s:
if stride_length_s is None:

View File

@@ -1148,7 +1148,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
num_beams=1,
)
transcription_non_ass = pipe(sample.copy(), generate_kwargs={"assistant_model": assistant_model})["text"]
transcription_non_ass = pipe(sample, generate_kwargs={"assistant_model": assistant_model})["text"]
transcription_ass = pipe(sample)["text"]
self.assertEqual(transcription_ass, transcription_non_ass)