-
Notifications
You must be signed in to change notification settings - Fork 0
/
PredictFruit.py
86 lines (64 loc) · 2.69 KB
/
PredictFruit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 24 23:09:56 2019
@author: Sahan Dilshan
"""
import numpy as np
from keras.models import load_model
from keras.preprocessing import image
import os
import cv2
from os import listdir
from os.path import isfile, join
from keras.preprocessing.image import ImageDataGenerator
# The validation directory holds one sub-directory per class. The generator is
# used here to recover the class-index -> class-name mapping Keras assigns.
validation_data_dir = 'fruits/validation'
img_width, img_height, img_depth = 32, 32, 3
batch_size = 64
# Scales pixel values by 1/255.
validation_datagen = ImageDataGenerator(rescale=1./255)
validation_generator = validation_datagen.flow_from_directory(
    validation_data_dir,
    # Keras expects target_size as (height, width), not (width, height).
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=False)
# class_indices maps folder name -> class index. Invert it so that a predicted
# index can be turned back into a readable label.
class_labels = {v: k for k, v in validation_generator.class_indices.items()}
model = load_model('Trained Models/fruits_fresh_cnn_1.h5')
print("model was successfully loaded.")
def draw_test(name, pred, im, true_label):
    """Display *im* in an OpenCV window with predicted and true labels drawn on it.

    A black border (160 px on top, 500 px on the right) is added around the
    image so the text has room without covering the picture.

    Args:
        name: OpenCV window name passed to ``cv2.imshow``.
        pred: predicted class name (string).
        im: image array, e.g. as returned by ``cv2.imread``.
        true_label: ground-truth class name (string).
    """
    black = [0, 0, 0]
    # copyMakeBorder order is top, bottom, left, right.
    expanded_image = cv2.copyMakeBorder(im, 160, 0, 0, 500, cv2.BORDER_CONSTANT, value=black)
    # OpenCV colours are BGR: (0, 0, 255) is red, (0, 255, 0) is green.
    cv2.putText(expanded_image, "predicted - " + pred, (20, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
    cv2.putText(expanded_image, "true - " + true_label, (20, 120), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
    cv2.imshow(name, expanded_image)
def getRandomImage(path, img_width, img_height):
    """Load a random image from a randomly chosen class folder under *path*.

    *path* must contain one sub-directory per class. The name of the chosen
    sub-directory is returned as the image's true label.

    Args:
        path: directory whose sub-directories hold the images of each class.
            A trailing separator is optional.
        img_width: width the image is resized to on load.
        img_height: height the image is resized to on load.

    Returns:
        A ``(image, file_path, class_name)`` tuple, where ``image`` is what
        ``keras.preprocessing.image.load_img`` returns.

    Raises:
        ValueError: if *path* has no sub-directories, or the chosen
            sub-directory contains no files.
    """
    # Sort the listings: os.listdir order is arbitrary, so sorting makes the
    # choice reproducible for a fixed NumPy seed.
    folders = sorted(d for d in listdir(path) if os.path.isdir(join(path, d)))
    if not folders:
        raise ValueError("no class folders found in {!r}".format(path))
    path_class = folders[np.random.randint(0, len(folders))]
    # join() instead of string concatenation, so *path* need not end with '/'.
    file_path = join(path, path_class)
    file_names = sorted(f for f in listdir(file_path) if isfile(join(file_path, f)))
    if not file_names:
        raise ValueError("no files found in {!r}".format(file_path))
    final_path = join(file_path, file_names[np.random.randint(0, len(file_names))])
    # Keras load_img takes target_size as (height, width).
    return image.load_img(final_path, target_size=(img_height, img_width)), final_path, path_class
files = []
true_labels = []
path = './fruits/validation/'
# Pick 10 random validation images, remembering their paths and true labels.
batch = []
for _ in range(10):
    img, final_path, true_label = getRandomImage(path, img_width, img_height)
    files.append(final_path)
    true_labels.append(true_label)
    # Scale pixel values by 1/255 (same arithmetic as the original x * 1./255).
    batch.append(image.img_to_array(img) / 255.0)
# Predict all images in a single batched call. Sequential.predict_classes was
# removed in TensorFlow 2.6; argmax over the per-class scores replaces it.
# NOTE(review): assumes the model has one output unit per class
# (class_mode='categorical'); confirm for the saved model.
scores = model.predict(np.stack(batch))
predictions = [int(c) for c in np.argmax(scores, axis=-1)]
for final_path, pred, true_label in zip(files, predictions, true_labels):
    img2 = cv2.imread(final_path)
    draw_test("Prediction", class_labels[pred], img2, true_label)
    cv2.waitKey(0)
cv2.destroyAllWindows()