-
Notifications
You must be signed in to change notification settings - Fork 17
/
train_model_ANN_audio.py
60 lines (43 loc) · 1.75 KB
/
train_model_ANN_audio.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import pickle as cPickle

import numpy as np
import tensorflow as tf

import model_audio_ANN
pickle_path = './pickle_data/'


def _load_pickle(file_name):
    """Load and return the single pickled object stored in ``pickle_path``/``file_name``.

    NOTE(review): unpickling can execute arbitrary code. Only point
    ``pickle_path`` at files produced by this project's own preprocessing,
    never at downloaded or otherwise untrusted data.
    """
    with open(os.path.join(pickle_path, file_name), "rb") as input_file:
        return cPickle.load(input_file)


print("loading pickle files for ANN")
x_train = _load_pickle("train_data_22k_org.pickle")
y_train = _load_pickle("train_labels_22k_org.pickle")
x_test = _load_pickle("test_data_22k_org.pickle")
y_test = _load_pickle("test_labels_22k_org.pickle")
x_valid = _load_pickle("valid_data_22k_org.pickle")
y_valid = _load_pickle("valid_labels_22k_org.pickle")

# The test split is folded into the training data here, so the validation
# split is the only data the network never trains on. Any later evaluation
# on x_test / y_test measures fit on already-seen samples, not generalization.
# np.vstack is the canonical name; np.row_stack is a deprecated alias of it
# (NumPy 2.0), so the stacked result is unchanged.
x_train = np.vstack([x_train, x_test])
y_train = np.vstack([y_train, y_test])
ann_model_dir = './model/ANN/'
##############
# Train ANN ##
##############
NUM_EPOCHS = 500
BATCH_SIZE = 64
# The network's input width is taken from the training matrix's second
# dimension (presumably the number of audio features per sample).
MODEL = model_audio_ANN.build_tflearn_ann(x_train.shape[1])
# x_valid / y_valid are passed as the validation set and are never part of
# x_train, so they are the held-out data reported during fitting.
MODEL.fit(x_train, y_train, n_epoch=NUM_EPOCHS,
          shuffle=True,
          validation_set=(x_valid, y_valid),
          show_metric=True,
          batch_size=BATCH_SIZE)
# The TF1 tf.train.Saver (used by tflearn's save, presumably what MODEL is)
# raises ValueError when the target's parent directory is missing. Create it
# up front so a long training run is not lost at the final save step.
os.makedirs(ann_model_dir, exist_ok=True)
MODEL.save(os.path.join(ann_model_dir, 'Bee_audio_ANN.tfl'))
# NOTE(review): earlier in this script x_test/y_test are stacked into
# x_train/y_train before fitting, so neither score below is a held-out
# estimate. The validation metrics from fit() are the unbiased ones.
print(MODEL.evaluate(x_test, y_test))
print(MODEL.evaluate(x_train, y_train))
# tf.reset_default_graph()
# ann_model_dir = './model/ANN_new/Bee_audio_ANN.tfl'
# ann_model = model_audio_ANN.build_tflearn_ann(x_test.shape[1])
# ann_model.load(ann_model_dir, weights_only=True, create_new_session = False)
# print(ann_model.evaluate(x_test, y_test))
# print(ann_model.evaluate(x_train, y_train))
# validation_acc = ann_model.evaluate(x_test, y_test)
# print(validation_acc)