fix asr ut failures (#41332)

Signed-off-by: Yao, Matrix <matrix.yao@intel.com>
Author: Yao Matrix
Date: 2025-10-06 10:12:19 -07:00
Committed by: GitHub
parent 57e82745f9
commit 73f8c4b8ad
2 changed files with 6 additions and 6 deletions

src/transformers/generation/logits_process.py

@@ -1915,7 +1915,7 @@ class SuppressTokensLogitsProcessor(LogitsProcessor):
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)
-        suppress_token_mask = isin_mps_friendly(vocab_tensor, self.suppress_tokens)
+        suppress_token_mask = isin_mps_friendly(vocab_tensor, self.suppress_tokens.to(scores.device))
         scores = torch.where(suppress_token_mask, -float("inf"), scores)
         return scores
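The change above fixes a device mismatch: `scores` sits on the accelerator (CUDA, XPU, or MPS) while `self.suppress_tokens` may still live on CPU, so the membership lookup fails at runtime. A minimal standalone sketch of the failure mode and the fix, using plain `torch.isin` in place of the library's `isin_mps_friendly` helper:

import torch

# Stand-ins for the tensors in the diff; this illustrates the device-mismatch
# pattern and is not the transformers implementation itself.
device = "cuda" if torch.cuda.is_available() else "cpu"

scores = torch.randn(1, 51865, device=device)    # logits on the accelerator
suppress_tokens = torch.tensor([220, 50257])     # token ids, created on CPU

vocab_tensor = torch.arange(scores.shape[-1], device=scores.device)

# With operands on different devices, torch.isin raises a RuntimeError on
# accelerator backends; moving the token ids onto scores.device avoids it.
mask = torch.isin(vocab_tensor, suppress_tokens.to(scores.device))
scores = torch.where(mask, -float("inf"), scores)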

tests/pipelines/test_pipelines_automatic_speech_recognition.py

@@ -1104,7 +1104,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
     def test_speculative_decoding_whisper_non_distil(self):
         # Load data:
         dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")
-        sample = dataset[0]["audio"]
+        sample = dataset[0]["audio"].get_all_samples().data

         # Load model:
         model_id = "openai/whisper-large-v2"
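This test change tracks the newer audio decoding in `datasets`: with the torchcodec-backed `Audio` feature, indexing the `"audio"` column returns a decoder object rather than a dict with a NumPy `"array"`. A hedged sketch of the new access pattern (the attribute names follow torchcodec's `AudioSamples`; the exact `datasets` version that changed the behavior is an assumption):

from datasets import load_dataset

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation[:1]")

audio = dataset[0]["audio"]        # AudioDecoder, no longer {"array": ..., "sampling_rate": ...}
samples = audio.get_all_samples()  # decode the full clip into an AudioSamples object
waveform = samples.data            # torch.Tensor, shape (num_channels, num_samples)
sample_rate = samples.sample_rate  # e.g. 16000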
@@ -1133,8 +1133,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
             num_beams=1,
         )

-        transcription_non_ass = pipe(sample.copy(), generate_kwargs={"assistant_model": assistant_model})["text"]
-        transcription_ass = pipe(sample)["text"]
+        transcription_ass = pipe(sample.clone().detach(), generate_kwargs={"assistant_model": assistant_model})["text"]
+        transcription_non_ass = pipe(sample)["text"]

         self.assertEqual(transcription_ass, transcription_non_ass)
         self.assertEqual(
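Since the decoded sample is now a `torch.Tensor`, NumPy's `ndarray.copy()` is no longer available on it; `clone().detach()` is the tensor idiom for an independent copy. A standalone sketch of that substitution (not the test itself):

import torch

sample = torch.randn(1, 16000)  # stand-in for a decoded waveform tensor

# clone() copies the storage and detach() drops any autograd history, so the
# pipeline can consume the copy without affecting the original sample.
copy_of_sample = sample.clone().detach()

assert copy_of_sample.data_ptr() != sample.data_ptr()  # separate storage
assert not copy_of_sample.requires_grad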
@@ -1422,13 +1422,13 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
         )
         dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
-        sample = dataset[0]["audio"]
+        sample = dataset[0]["audio"].get_all_samples().data

         # prompt the model to misspell "Mr Quilter" as "Mr Quillter"
         whisper_prompt = "Mr. Quillter."
         prompt_ids = pipe.tokenizer.get_prompt_ids(whisper_prompt, return_tensors="pt").to(torch_device)

-        unprompted_result = pipe(sample.copy())["text"]
+        unprompted_result = pipe(sample.clone().detach())["text"]
         prompted_result = pipe(sample, generate_kwargs={"prompt_ids": prompt_ids})["text"]

         # fmt: off