Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)
fix asr ut failures (#41332)
Signed-off-by: Yao, Matrix <matrix.yao@intel.com>
@@ -1915,7 +1915,7 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
-        suppress_token_mask = isin_mps_friendly(vocab_tensor, self.suppress_tokens)
+        suppress_token_mask = isin_mps_friendly(vocab_tensor, self.suppress_tokens.to(scores.device))
         scores = torch.where(suppress_token_mask, -float("inf"), scores)
         return scores
 
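The hunk above moves the suppress-token ids onto the same device as the score matrix before the membership test. Below is a minimal, self-contained sketch of why that matters; it uses plain torch.isin in place of the repository's isin_mps_friendly helper and placeholder shapes, so treat it as an illustration of the device-mismatch failure mode rather than the library code itself.

import torch

# Minimal sketch (illustration only): a trimmed-down version of the
# suppress-tokens step, with torch.isin standing in for isin_mps_friendly.
def suppress(scores: torch.FloatTensor, suppress_tokens: torch.LongTensor) -> torch.FloatTensor:
    vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
    # The fix: bring the (possibly CPU-resident) suppress_tokens onto the same
    # device as scores. torch.isin needs both operands on one device, so a
    # CUDA/XPU score matrix combined with CPU token ids would otherwise fail.
    mask = torch.isin(vocab_tensor, suppress_tokens.to(scores.device))
    return torch.where(mask, -float("inf"), scores)

scores = torch.zeros(1, 10)              # placeholder (batch, vocab) score matrix
suppress_tokens = torch.tensor([2, 5])   # placeholder token ids to silence
print(suppress(scores, suppress_tokens)) # columns 2 and 5 become -inf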
@@ -1104,7 +1104,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
     def test_speculative_decoding_whisper_non_distil(self):
         # Load data:
         dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")
-        sample = dataset[0]["audio"]
+        sample = dataset[0]["audio"].get_all_samples().data
 
         # Load model:
         model_id = "openai/whisper-large-v2"
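The test hunks all follow the same pattern: with recent datasets releases the audio column is decoded through torchcodec, so indexing it yields a decoder object instead of the old dict of NumPy arrays, and the raw waveform has to be pulled out explicitly. A rough sketch of the new access path, assuming a torchcodec-backed datasets install:

from datasets import load_dataset

# Sketch of the new audio access pattern (assumes a datasets version whose
# Audio feature decodes through torchcodec; older versions return a dict with
# a NumPy "array" key instead).
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")
decoder = dataset[0]["audio"]            # torchcodec AudioDecoder-like object
samples = decoder.get_all_samples()      # decoded samples for the whole clip
waveform = samples.data                  # torch.Tensor handed to the ASR pipeline
print(type(waveform), waveform.shape)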
@@ -1133,8 +1133,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
             num_beams=1,
         )
 
-        transcription_non_ass = pipe(sample.copy(), generate_kwargs={"assistant_model": assistant_model})["text"]
-        transcription_ass = pipe(sample)["text"]
+        transcription_ass = pipe(sample.clone().detach(), generate_kwargs={"assistant_model": assistant_model})["text"]
+        transcription_non_ass = pipe(sample)["text"]
 
         self.assertEqual(transcription_ass, transcription_non_ass)
         self.assertEqual(
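Because the sample is now a torch tensor rather than a NumPy array, duplicating it for the two pipeline calls changes from .copy() to .clone().detach(); the hunk also puts the transcription_ass / transcription_non_ass names on the right calls. A small illustrative comparison of the two duplication idioms (not part of the commit):

import numpy as np
import torch

# NumPy path (old behaviour): the decoded audio arrived as an ndarray.
np_sample = np.zeros(16000, dtype=np.float32)
np_dup = np_sample.copy()                  # independent NumPy copy

# Torch path (new behaviour): torchcodec hands back a tensor, which has no
# zero-argument .copy(); .clone().detach() is the tensor equivalent and also
# drops any autograd history.
pt_sample = torch.zeros(16000)
pt_dup = pt_sample.clone().detach()
assert pt_dup.data_ptr() != pt_sample.data_ptr()  # storage is not shared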
@@ -1422,13 +1422,13 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
         )
 
         dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
-        sample = dataset[0]["audio"]
+        sample = dataset[0]["audio"].get_all_samples().data
 
         # prompt the model to misspell "Mr Quilter" as "Mr Quillter"
         whisper_prompt = "Mr. Quillter."
         prompt_ids = pipe.tokenizer.get_prompt_ids(whisper_prompt, return_tensors="pt").to(torch_device)
 
-        unprompted_result = pipe(sample.copy())["text"]
+        unprompted_result = pipe(sample.clone().detach())["text"]
         prompted_result = pipe(sample, generate_kwargs={"prompt_ids": prompt_ids})["text"]
 
         # fmt: off
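The last hunk applies the same tensor handling to the prompted-transcription test, whose context lines build prompt_ids with the Whisper tokenizer. A hedged usage sketch of that prompting call, using the small openai/whisper-tiny checkpoint as a stand-in for whatever model the test actually loads:

from transformers import WhisperTokenizer

# Sketch of Whisper prompting (checkpoint chosen for size; the test's own
# model id is not shown in this hunk).
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
prompt_ids = tokenizer.get_prompt_ids("Mr. Quillter.", return_tensors="pt")
print(prompt_ids)  # token ids later passed via generate_kwargs={"prompt_ids": ...}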