# main1-CNN.py
"""Run the project's CNN on the ReadDataset training/testing splits.

Loads both splits with ``ReadDataset.read``, one-hot encodes the labels with
a class count shared by both splits, builds the CNN output layer, passes
everything to ``CNN.calculate_session`` and prints the accuracy and predicted
labels it returns.
"""
import sys

import numpy as np
import tensorflow as tf

from AutoEncoder import AutoEncoder  # noqa: F401  -- unused here; kept in case importing it has side effects
from CNN import CNN
from ReadDataset import read


def _one_hot(labels, depth):
    """Return a float64 one-hot matrix of shape (len(labels), depth) for integer *labels*.

    Equivalent to ``tf.one_hot(labels, depth, dtype=float64)`` for labels in
    ``[0, depth)``, but computed directly in NumPy without building graph ops.
    """
    indices = np.asarray(labels, dtype=np.int64)
    return np.eye(depth, dtype=np.float64)[indices]


def main():
    """Load data, one-hot encode labels, run the CNN and print its results."""
    # Print arrays in full. Recent NumPy rejects threshold=np.nan with a
    # ValueError; sys.maxsize is the documented "never summarize" value.
    np.set_printoptions(threshold=sys.maxsize)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.allocator_type = 'BFC'
    # InteractiveSession registers itself as the default session; presumably
    # CNN builds/runs its graph through that default session -- confirm in CNN.py.
    sess = tf.InteractiveSession(config=config)
    try:
        training_data, training_label = read()
        test_data, test_label = read(dataset="testing")

        # One hot encoding. Use one class count for both splits so the label
        # matrices always have the same width even if a split lacks the
        # highest class; cast to int before +1 so small unsigned label dtypes
        # (e.g. uint8) cannot wrap around.
        num_classes = int(max(np.max(training_label), np.max(test_label))) + 1
        training_label = _one_hot(training_label, num_classes)
        test_label = _one_hot(test_label, num_classes)

        cnn = CNN(epochs=2, batch_size=800, isSave=False)
        # output layer
        y_ = cnn.initial_mlp_network()
        cnn_accuracy, cnn_predict_label = cnn.calculate_session(
            y_, training_data, training_label, test_data, test_label)
        print(cnn_accuracy, cnn_predict_label)
    finally:
        sess.close()


if __name__ == "__main__":
    main()