-
Notifications
You must be signed in to change notification settings - Fork 57
/
wav2vec2_inference.py
67 lines (56 loc) · 3.02 KB
/
wav2vec2_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import soundfile as sf
import torch
from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor
# Possible improvements (TODO):
# - resample input audio whose sample rate is not 16 kHz
# - log the inference time
class Wave2Vec2Inference:
    """Speech-to-text wrapper around a pretrained CTC model and its processor.

    Loads the processor and model named by ``model_name`` and offers helpers
    that turn an audio buffer or an audio file into a
    ``(transcription, confidence)`` pair.
    """

    def __init__(self, model_name, hotwords=None, use_lm_if_possible=True, use_gpu=True):
        """Load processor and model and move the model to the chosen device.

        Args:
            model_name: Identifier handed to ``from_pretrained`` for both the
                processor and the model.
            hotwords: Optional iterable of words forwarded to the processor's
                ``decode`` call when the decoder path is used. ``None``
                (the default) means no hotwords.
            use_lm_if_possible: When true, load ``AutoProcessor`` and decode
                through its ``decoder`` attribute if it has one; when false,
                load ``Wav2Vec2Processor`` and decode greedily.
            use_gpu: When true, use CUDA if ``torch.cuda.is_available()``;
                otherwise the model stays on the CPU.
        """
        self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
        if use_lm_if_possible:
            self.processor = AutoProcessor.from_pretrained(model_name)
        else:
            self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.model = AutoModelForCTC.from_pretrained(model_name)
        self.model.to(self.device)
        # A None default avoids sharing one list object between instances.
        self.hotwords = [] if hotwords is None else list(hotwords)
        self.use_lm_if_possible = use_lm_if_possible

    def buffer_to_text(self, audio_buffer):
        """Transcribe an in-memory audio buffer.

        Args:
            audio_buffer: Sequence of audio samples. It is passed to the
                processor with ``sampling_rate=16_000``, so the caller is
                responsible for supplying 16 kHz audio.

        Returns:
            A ``(transcription, confidence)`` tuple. An empty buffer yields
            ``("", 0.0)`` so that callers can always unpack two values. On
            the decoder path the confidence is ``lm_score`` divided by the
            number of space-separated words; otherwise it is the result of
            ``confidence_score``.
        """
        # len() rather than truthiness: NumPy arrays are ambiguous as booleans.
        if len(audio_buffer) == 0:
            return "", 0.0

        inputs = self.processor(
            torch.tensor(audio_buffer),
            sampling_rate=16_000,
            return_tensors="pt",
            padding=True,
        )
        with torch.no_grad():
            logits = self.model(
                inputs.input_values.to(self.device),
                attention_mask=inputs.attention_mask.to(self.device),
            ).logits

        if hasattr(self.processor, 'decoder') and self.use_lm_if_possible:
            decoded = self.processor.decode(
                logits[0].cpu().numpy(),
                hotwords=self.hotwords,
                output_word_offsets=True,
            )
            # split(" ") always returns at least one element, so this
            # division cannot hit zero even for an empty transcription.
            confidence = decoded.lm_score / len(decoded.text.split(" "))
            transcription = decoded.text
        else:
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = self.processor.batch_decode(predicted_ids)[0]
            confidence = self.confidence_score(logits, predicted_ids)

        return transcription, confidence

    def confidence_score(self, logits, predicted_ids):
        """Average the softmax probability of the predicted character tokens.

        Positions whose predicted id equals the tokenizer's word delimiter or
        pad token are excluded from the average.

        Args:
            logits: Model output; softmax is taken over its last dimension.
            predicted_ids: Predicted token ids, one per position of ``logits``
                (i.e. ``logits`` without its last dimension).

        Returns:
            A zero-dimensional tensor with the mean probability, or 0 when no
            position survives the mask (instead of NaN from dividing by zero).
        """
        scores = torch.nn.functional.softmax(logits, dim=-1)
        pred_scores = scores.gather(-1, predicted_ids.unsqueeze(-1))[:, :, 0]
        tokenizer = self.processor.tokenizer
        mask = torch.logical_and(
            predicted_ids.not_equal(tokenizer.word_delimiter_token_id),
            predicted_ids.not_equal(tokenizer.pad_token_id),
        )
        character_scores = pred_scores.masked_select(mask)
        # The sum of an empty tensor is 0; clamping the divisor to 1 turns the
        # all-masked case (e.g. only pad tokens predicted) into 0 rather than NaN.
        return character_scores.sum() / max(character_scores.numel(), 1)

    def file_to_text(self, filename):
        """Read an audio file and transcribe it.

        Args:
            filename: Path (or anything ``soundfile.read`` accepts) of a
                16 kHz audio file.

        Returns:
            The ``(transcription, confidence)`` tuple from ``buffer_to_text``.

        Raises:
            ValueError: If the file's sample rate is not 16000 Hz.
        """
        audio_input, samplerate = sf.read(filename)
        # Explicit raise instead of assert: asserts are removed under ``-O``.
        if samplerate != 16000:
            raise ValueError(
                f"expected a 16000 Hz audio file, got {samplerate} Hz: {filename}"
            )
        # soundfile returns (frames, channels) for multi-channel files;
        # average the channels so the model receives a single mono signal.
        if audio_input.ndim > 1:
            audio_input = audio_input.mean(axis=1)
        return self.buffer_to_text(audio_input)
if __name__ == "__main__":
    # Manual smoke test: load the German model, transcribe "test.wav" from
    # the working directory and print whatever file_to_text returns.
    print("Model test")
    recognizer = Wave2Vec2Inference("oliverguhr/wav2vec2-large-xlsr-53-german-cv9")
    print(recognizer.file_to_text("test.wav"))