This repository has been archived by the owner on Oct 18, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
/
audio_classifier.py
79 lines (69 loc) · 2.93 KB
/
audio_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Load the API (Current warning is related to h5py and has no consequences)
import os
import time
import numpy as np
from inaSpeechSegmenter import Segmenter, seg2csv
from pyAudioAnalysis import audioSegmentation as aS
from db_handler import *
class audio_classifier():
    """Classify audio files as "Speech", "Music" or "Mixed" and store the results.

    Construction runs the whole pipeline immediately: every file in *files*
    is classified (:meth:`classify`), one record per file is saved through
    ``db_handler`` (:meth:`save_results`), and a summary is printed
    (:meth:`print_results`).

    Args:
        files: Iterable of audio file paths. Each path's basename is used as
            the display name and as the lookup key into *tags*.
        tags: Optional mapping ``basename -> expected class label``; used only
            by :meth:`print_results` to report accuracy.
        algo: Backend, ``"ina"`` (inaSpeechSegmenter, the default) or
            ``"paa"`` (pyAudioAnalysis SVM classifier).
    """

    # Label used when an ina segmentation shows neither speech nor music
    # (e.g. only silence/noise), so every file still gets exactly one result.
    UNKNOWN = "Unknown"

    def __init__(self, files, tags=None, algo="ina"):
        # Materialize so len() works and the files can be iterated repeatedly
        # (classify, save_results and print_results each walk the list).
        self.media = list(files)
        self.tags = tags
        self.algo = algo  # "ina" or "paa"
        self.results = []       # predicted label per file, aligned with self.media
        self.times = []         # elapsed whole seconds per file, aligned with self.media
        self.segmentation = []  # raw ina segmentation per file (ina backend only)
        self.meta = []          # NOTE(review): never populated in this class
        self.db_server = db_handler()
        self.classify()
        self.save_results()
        self.print_results()

    def classify(self):
        """Classify every file in ``self.media`` with the configured backend.

        Prints a progress line per file and appends exactly one label to
        ``self.results`` and one elapsed time (rounded whole seconds) to
        ``self.times`` for each file. With the ``"ina"`` backend the raw
        segmentation is also appended to ``self.segmentation``.

        Raises:
            ValueError: if ``self.algo`` is neither ``"ina"`` nor ``"paa"``.
        """
        if self.algo not in ("ina", "paa"):
            # Fail fast: skipping files would leave self.results shorter than
            # self.media and corrupt save_results / print_results.
            raise ValueError(
                "unknown algo {!r}; expected 'ina' or 'paa'".format(self.algo)
            )
        if self.algo == "ina" and getattr(self, "seg", None) is None:
            # Reuse an existing segmenter instead of constructing a new one on
            # every call (construction presumably loads model weights).
            self.seg = Segmenter()
        total = len(self.media)
        for index, audioPath in enumerate(self.media, start=1):
            # monotonic() is immune to system clock changes; round only the
            # difference so durations are not skewed by per-timestamp rounding.
            start = time.monotonic()
            vid = os.path.basename(audioPath)
            print("### {}/{} Processing {} ###".format(index, total, vid))
            if self.algo == "ina":
                segmentation = self.seg(audioPath)
                self.segmentation.append(segmentation)
                label = self._label_from_segmentation(segmentation)
            else:
                label = self._label_from_paa(audioPath)
            self.results.append(label)
            self.times.append(int(round(time.monotonic() - start)))

    @classmethod
    def _label_from_segmentation(cls, segmentation):
        """Map one inaSpeechSegmenter result to "Mixed", "Music", "Speech" or UNKNOWN."""
        # Match on the text of the whole result, case-insensitively, so both
        # capitalised ("Male"/"Female"/"Music") and lower-case
        # ("male"/"female"/"music") label spellings are recognised.
        # "male" is a substring of "female", so one test covers both; a
        # generic "speech" label is accepted as well.
        text = str(segmentation).lower()
        has_speech = "male" in text or "speech" in text
        has_music = "music" in text
        if has_speech and has_music:
            return "Mixed"
        if has_music:
            return "Music"
        if has_speech:
            return "Speech"
        return cls.UNKNOWN

    @staticmethod
    def _label_from_paa(audioPath):
        """Classify one file with pyAudioAnalysis and map it to a label."""
        flagsInd, _classes, _acc, _cm = aS.mtFileClassification(
            audioPath, "svmSM/svmSM", "svm", False, ''
        )
        # Mean of the per-window class indices. The thresholds below imply
        # index 0 is speech and index 1 is music in the svmSM model
        # (NOTE(review): confirm against the model's class list).
        res = np.array(flagsInd).mean()
        if res <= 0.1:
            return "Speech"
        if res >= 0.9:
            return "Music"
        return "Mixed"

    def save_results(self):
        """Save one record per classified file through ``self.db_server``.

        Each record holds the file's basename, full path, the module-level
        ``user_default`` value under the ``'kullanici'`` (Turkish: "user") key,
        and the predicted class. The full path is used as the document id.
        """
        # classify() guarantees exactly one result per file, so zip pairs
        # every path with its own label.
        for audioPath, label in zip(self.media, self.results):
            vid = os.path.basename(audioPath)
            data = {'name': vid, 'path': audioPath, 'kullanici': user_default, 'class': label}
            self.db_server.save(db_ac, data, doc_id=audioPath)

    def print_results(self):
        """Print each file's result and time, plus accuracy when tags exist.

        Files that have a tag also show the expected ("Actual") label and
        count toward the accuracy line. Untagged files are listed without it.
        The accuracy line is printed only when ``self.tags`` is non-empty.
        """
        print("\n\n### Results ###\n")
        correct = 0
        evaluated = 0
        for audioPath, label, seconds in zip(self.media, self.results, self.times):
            print(audioPath)
            vid = os.path.basename(audioPath)
            expected = self.tags.get(vid) if self.tags else None
            if expected is None:
                print("# {} # Result: {} # Time: {}".format(vid, label, seconds))
                continue
            evaluated += 1
            if label == expected:
                correct += 1
            print("# {} # Result: {} # Actual: {} # Time: {}".format(vid, label, expected, seconds))
        if self.tags:
            print("correct answer rate: {}/{}".format(correct, evaluated))