-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_template_identification.py
66 lines (50 loc) · 2.92 KB
/
test_template_identification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import itertools

import numpy
from prettytable import PrettyTable

import cnn_models
from loaddata import load_speakers_data_identification, save_data_to_files
def get_model(mfcc=13, deltas=True, frames=25, train_templates=48, num_speakers=50, num_female=9):
    """Load speaker-identification data and build the CNN that classifies it.

    The train/test split is loaded with ``load_speakers_data_identification``,
    written to ``data/*.npy`` via ``save_data_to_files``, reshaped to
    ``(n, frames, num_features, 1)`` and paired with a model from
    ``cnn_models.get_third_model`` whose ``input_shape`` matches that layout.

    :param mfcc: number of MFCC coefficients requested from the loader.
    :param deltas: whether delta features are used; this triples the per-frame
        feature count.
    :param frames: number of frames per template.
    :param train_templates: number of training templates (presumably per
        speaker, like the test count). The test count is derived from it so the
        train:test ratio is 60%:40%.
    :param num_speakers: number of enrolled speakers; also the model's number
        of output classes.
    :param num_female: forwarded to the loader as ``num_female`` (presumably
        the number of female speakers among ``num_speakers``).
    :return: ``(model, (x_train, y_train), (x_test, y_test))``
    """
    # Number of test templates per speaker: train / 1.5 gives a 60%:40% split.
    num_test_templates = int(train_templates // 1.5)
    # Number of coefficients in one frame. The "- 1" presumably accounts for the
    # loader dropping one coefficient per frame -- confirm against loaddata.
    num_features = (mfcc - 1) * 3 if deltas else mfcc - 1

    (x_train, y_train), (x_test, y_test) = \
        load_speakers_data_identification(num_frames=frames, num_mfcc=mfcc, use_deltas=deltas,
                                          num_speakers=num_speakers, num_female=num_female,
                                          num_train_templates=train_templates,
                                          num_test_templates=num_test_templates)
    # Cache the extracted arrays under data/ so they can be reused without re-extraction.
    save_data_to_files(x_train, 'data/x_train.npy', y_train, 'data/y_train.npy', x_test, 'data/x_test.npy', y_test,
                       'data/y_test.npy')

    # Append a trailing axis of size 1 so each template matches the CNN's
    # input_shape of (frames, num_features, 1).
    x_train = x_train.reshape(x_train.shape[0], frames, num_features, 1)
    x_test = x_test.reshape(x_test.shape[0], frames, num_features, 1)
    model = cnn_models.get_third_model(input_shape=(frames, num_features, 1), num_classes=num_speakers)
    return model, (x_train, y_train), (x_test, y_test)


def find_best_template_params(epochs, mfccs=(13, 22, 31), deltas=(False, True), frames=(25, 50, 100),
                              train_frames=1200):
    """Grid-search template parameters by cross-validated micro-averaged F1 score.

    For every combination of MFCC count, delta usage and frames per template
    (iterated in that nesting order), data and a model are obtained from
    ``get_model``. Each combination is scored with
    ``cnn_models.k_fold_cross_val_score_f1_micro``. A summary table is printed,
    followed by the best-scoring combination.

    :param epochs: forwarded to ``k_fold_cross_val_score_f1_micro``.
    :param mfccs: MFCC counts to try.
    :param deltas: values of the ``use_deltas`` flag to try.
    :param frames: frames-per-template values to try.
    :param train_frames: each combination uses ``train_frames // frame`` training
        templates, so templates x frames-per-template stays roughly constant
        across combinations.
    """
    results = []
    for mfcc, delta, frame in itertools.product(mfccs, deltas, frames):
        model, (x, y), _ = get_model(mfcc, delta, frame, train_frames // frame)
        print('mfcc = {0}, use_deltas = {1}, frames = {2}'.format(mfcc, delta, frame))
        # NOTE(review): this factory returns the same model instance on every call.
        # If k_fold_cross_val_score_f1_micro calls it once per fold, later folds
        # keep training weights learned on earlier folds. Confirm, and build a fresh
        # model per call if so.
        scores = cnn_models.k_fold_cross_val_score_f1_micro(x, y, lambda: model, epochs)
        results.append((scores, mfcc, delta, frame))

    if not results:
        print('No parameter combinations to evaluate.')
        return

    print("Results:")
    table = PrettyTable(['mfcc', 'use_deltas', 'frames', 'F1-score', 'F1-score std'])
    for scores, mfcc, delta, frame in results:
        table.add_row([mfcc, delta, frame, round(scores.mean(), 4), round(scores.std(), 4)])
    print(table)

    # max() returns the first maximal entry, matching a strict '>' scan. Unlike a
    # scan seeded with a dummy 0.0 score, it always reports a real configuration.
    best_scores, best_mfcc, best_delta, best_frame = max(results, key=lambda result: result[0].mean())
    print('Best F1-score = {0}; mfcc = {1}, use_deltas = {2}, frames = {3}'.
          format(round(best_scores.mean(), 4), best_mfcc, best_delta, best_frame))


if __name__ == '__main__':
    # Script entry point: run the template-parameter grid search with 20 training epochs.
    training_epochs = 20
    find_best_template_params(epochs=training_epochs)