Compare commits: v4.20.1...arijitx/wa (12 commits)
@ -216,3 +216,34 @@ PYTHONPATH=../../../src deepspeed --num_gpus 4 run_pretrain.py \
--fp16 \
--deepspeed ds_config_wav2vec2_zero2.json \
```

### Forced Alignment

Character-level forced alignment of audio and text pairs with wav2vec2 models fine-tuned on an ASR task for a specific language.
Inspired by [this](https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html) PyTorch tutorial.

#### Input Formats

    Input format in script.txt              Input format in wavs directory
    0000    sentence1                       0000.wav
    0001    sentence2                       0001.wav

#### Output Format

The output directory will contain 0000.txt and 0001.txt, one file per input wav. Each file has one line per aligned character in the format below, where start_ms and end_ms are the character's start and end times in milliseconds:

    char    score   start_ms    end_ms
    h       0.25    1440        1520

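Since each alignment file is plain tab-separated text, it is straightforward to load back into Python for downstream processing. Below is a minimal illustrative sketch (not part of the script); the `read_alignment` helper name and the `0000.txt` path are just examples matching the format above.

```
# Illustrative helper (not part of alignment.py): parse one output file.
# Assumes the tab-separated columns shown above: char, score, start_ms, end_ms.
def read_alignment(path):
    segments = []
    with open(path, encoding="utf8") as f:
        for line in f:
            char, score, start_ms, end_ms = line.rstrip("\n").split("\t")
            segments.append((char, float(score), int(start_ms), int(end_ms)))
    return segments

# e.g. list characters that span more than 100 ms
for char, score, start_ms, end_ms in read_alignment("./out_alignment/0000.txt"):
    if end_ms - start_ms > 100:
        print(char, score, start_ms, end_ms)
```
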
#### Run command

```
python alignment.py \
--model_name="arijitx/wav2vec2-xls-r-300m-bengali" \
--wav_dir="./wavs" \
--text_file="script.txt" \
--input_wavs_sr=48000 \
--output_dir="./out_alignment" \
--cuda
```

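As a rough usage sketch (assuming the example paths above and a standard torchaudio install), the millisecond offsets can be used to cut an aligned span back out of the original audio:

```
import torchaudio

# Load the original wav and slice out one aligned span,
# e.g. the character aligned to 1440-1520 ms in the example output above.
waveform, sample_rate = torchaudio.load("./wavs/0000.wav")
start_ms, end_ms = 1440, 1520
start_sample = int(start_ms / 1000 * sample_rate)
end_sample = int(end_ms / 1000 * sample_rate)
torchaudio.save("char_segment.wav", waveform[:, start_sample:end_sample], sample_rate)
```
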
							
								
								
									
examples/research_projects/wav2vec2/alignment.py (new file, 224 lines)
@ -0,0 +1,224 @@
# Parts of the code are adapted from the snippets provided in the TorchAudio Wav2Vec forced alignment tutorial.
# The full tutorial can be found here: https://pytorch.org/audio/stable/tutorials/forced_alignment_tutorial.html

import argparse
import os
from dataclasses import dataclass

import torch
import torchaudio
from tqdm import tqdm

from transformers import AutoConfig, AutoModelForCTC, AutoProcessor


class Wav2Vec2Aligner:
    def __init__(self, model_name, input_wavs_sr, cuda):
        self.cuda = cuda
        self.config = AutoConfig.from_pretrained(model_name)
        self.model = AutoModelForCTC.from_pretrained(model_name)
        self.model.eval()
        if self.cuda:
            self.model.to(device="cuda")
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.resampler = torchaudio.transforms.Resample(input_wavs_sr, 16_000)
        blank_id = 0
        vocab = list(self.processor.tokenizer.get_vocab().keys())
        for i in range(len(vocab)):
            if vocab[i] == "[PAD]" or vocab[i] == "<pad>":
                blank_id = i
        print("Blank Token id [PAD]/<pad>", blank_id)
        self.blank_id = blank_id

    def speech_file_to_array_fn(self, wav_path):
        speech_array, sampling_rate = torchaudio.load(wav_path)
        speech = self.resampler(speech_array).squeeze().numpy()
        return speech

    def align_single_sample(self, item):
        blank_id = self.blank_id
        transcript = "|".join(item["sent"].split(" "))
        if not os.path.isfile(item["wav_path"]):
            print(item["wav_path"], "not found in wavs directory")
            # skip samples whose audio file is missing
            return

        speech_array = self.speech_file_to_array_fn(item["wav_path"])
        inputs = self.processor(speech_array, sampling_rate=16_000, return_tensors="pt", padding=True)
        if self.cuda:
            inputs = inputs.to(device="cuda")

        with torch.no_grad():
            logits = self.model(inputs.input_values).logits

        # get the emission probability at frame level
        emissions = torch.log_softmax(logits, dim=-1)
        emission = emissions[0].cpu().detach()

        # get labels from vocab
        labels = ([""] + list(self.processor.tokenizer.get_vocab().keys()))[
            :-1
        ]  # logits don't align with the tokenizer's vocab

        dictionary = {c: i for i, c in enumerate(labels)}
        tokens = []
        for c in transcript:
            if c in dictionary:
                tokens.append(dictionary[c])

        def get_trellis(emission, tokens, blank_id=0):
            """
            Build a trellis matrix of shape (num_frames + 1, num_tokens + 1)
            that represents the probabilities of each source token being at a certain time step
            """
            num_frames = emission.size(0)
            num_tokens = len(tokens)

            # Trellis has extra dimensions for both time axis and tokens.
            # The extra dim for tokens represents <SoS> (start-of-sentence)
            # The extra dim for time axis is for simplification of the code.
            trellis = torch.full((num_frames + 1, num_tokens + 1), -float("inf"))
            trellis[:, 0] = 0
            for t in range(num_frames):
                trellis[t + 1, 1:] = torch.maximum(
                    # Score for staying at the same token
                    trellis[t, 1:] + emission[t, blank_id],
                    # Score for changing to the next token
                    trellis[t, :-1] + emission[t, tokens],
                )
            return trellis

        trellis = get_trellis(emission, tokens, blank_id)

        @dataclass
        class Point:
            token_index: int
            time_index: int
            score: float

        def backtrack(trellis, emission, tokens, blank_id=0):
            """
            Walk backwards from the last (sentence_token, time_step) pair to build the optimal sequence alignment path
            """
            # Note:
            # j and t are indices for trellis, which has extra dimensions
            # for time and tokens at the beginning.
            # When referring to time frame index `T` in trellis,
            # the corresponding index in emission is `T-1`.
            # Similarly, when referring to token index `J` in trellis,
            # the corresponding index in transcript is `J-1`.
            j = trellis.size(1) - 1
            t_start = torch.argmax(trellis[:, j]).item()

            path = []
            for t in range(t_start, 0, -1):
                # 1. Figure out if the current position was stay or change
                # Note (again):
                # `emission[T-1]` is the emission at time frame `T` of the trellis dimension.
                # Score for the token staying the same from time frame T-1 to T.
                stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
                # Score for the token changing from J-1 at T-1 to J at T.
                changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]

                # 2. Store the path with frame-wise probability.
                prob = emission[t - 1, tokens[j - 1] if changed > stayed else blank_id].exp().item()
                # Return token index and time index in non-trellis coordinate.
                path.append(Point(j - 1, t - 1, prob))

                # 3. Update the token
                if changed > stayed:
                    j -= 1
                    if j == 0:
                        break
            else:
                raise ValueError("Failed to align")
            return path[::-1]

        path = backtrack(trellis, emission, tokens, blank_id)

        @dataclass
        class Segment:
            label: str
            start: int
            end: int
            score: float

            def __repr__(self):
                # each emission frame covers 20 ms of the 16 kHz audio,
                # so frame indices are converted to milliseconds here
                return f"{self.label}\t{self.score:4.2f}\t{self.start*20:5d}\t{self.end*20:5d}"

            @property
            def length(self):
                return self.end - self.start

        def merge_repeats(path):
            """
            Merge repeated tokens into a single segment. Note: this shouldn't affect repeated characters from the
            original sentences (e.g. `ll` in `hello`)
            """
            i1, i2 = 0, 0
            segments = []
            while i1 < len(path):
                while i2 < len(path) and path[i1].token_index == path[i2].token_index:
                    i2 += 1
                score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
                segments.append(
                    Segment(
                        transcript[path[i1].token_index],
                        path[i1].time_index,
                        path[i2 - 1].time_index + 1,
                        score,
                    )
                )
                i1 = i2
            return segments

        segments = merge_repeats(path)
        with open(item["out_path"], "w") as out_align:
            for seg in segments:
                out_align.write(str(seg) + "\n")

    def align_data(self, wav_dir, text_file, output_dir):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        # load text file
        lines = open(text_file, encoding="utf8").readlines()

        items = []
        for line in lines:
            if len(line.strip().split("\t")) != 2:
                print("Each line of the script file must be tab separated: <wav_name><TAB><sentence>")
                exit()

            wav_name, sentence = line.strip().split("\t")
            wav_path = os.path.join(wav_dir, wav_name + ".wav")
            out_path = os.path.join(output_dir, wav_name + ".txt")

            items.append({"sent": sentence, "wav_path": wav_path, "out_path": out_path})
        print("Number of samples found in script file", len(items))

        for item in tqdm(items):
            self.align_single_sample(item)


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_name", type=str, default="arijitx/wav2vec2-xls-r-300m-bengali", help="wav2vec model name"
    )
    parser.add_argument("--wav_dir", type=str, default="./wavs", help="directory containing wavs")
    parser.add_argument("--text_file", type=str, default="script.txt", help="file containing text")
    parser.add_argument("--input_wavs_sr", type=int, default=16000, help="sampling rate of input audios")
    parser.add_argument(
        "--output_dir", type=str, default="./out_alignment", help="output directory containing the alignment files"
    )
    parser.add_argument("--cuda", action="store_true")

    args = parser.parse_args()

    aligner = Wav2Vec2Aligner(args.model_name, args.input_wavs_sr, args.cuda)
    aligner.align_data(args.wav_dir, args.text_file, args.output_dir)


if __name__ == "__main__":
    main()
							
								
								
									
examples/research_projects/wav2vec2/run_alignment.sh (new file, 8 lines)
@ -0,0 +1,8 @@
#!/usr/bin/env bash
python alignment.py \
--model_name="arijitx/wav2vec2-xls-r-300m-bengali" \
--wav_dir="./wavs" \
--text_file="script.txt" \
--input_wavs_sr=48000 \
--output_dir="./out_alignment" \
--cuda