-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_classifier.py
64 lines (55 loc) · 2.59 KB
/
run_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
###########
# Imports #
###########
""" Global """
import os
import argparse
from glob import glob
""" Local """
import constants
from Classifiers.BaselineClassifier import BaselineClassifier
from Classifiers.CustomClassifier import CustomClassifier
from Classifiers.DeepClassifier import DeepClassifier
from Classifiers.BOWClassifier import BOWClassifier
#############
# Functions #
#############
def parse_args():
    """Parse command-line arguments for running an image classifier.

    Two derived attributes are attached to the returned namespace:

    * ``catalog_images_paths``: sorted list of the entries directly inside
      ``catalog_images_folder``.
    * ``query_images_paths``: sorted list of the entries directly inside
      ``query_path``. It is only set when ``--one_query`` is NOT given; in
      that mode ``query_path`` is left for the caller to use as-is.

    Listing uses glob's ``*`` pattern, which does not match names starting
    with ``.`` (hidden files are skipped, as before).

    Defaults for the folder/annotation paths come from the project-level
    ``constants`` module.

    Returns:
        argparse.Namespace: the parsed options plus the derived path lists.
    """
    # Local import: the module only imports the ``glob`` function itself.
    from glob import escape as glob_escape

    parser = argparse.ArgumentParser(description="Arguments for running classifier")
    parser.add_argument(
        "-ci", "--catalog_images_folder", dest="catalog_images_folder",
        help="Path to catalog images folder",
        default=constants.CATALOG_IMAGES_PATH,
    )
    parser.add_argument(
        "-qi", "--query_path", dest="query_path",
        help="Path to query images",
        default=constants.CLASSIFICATION_QUERY_IMAGES_PATH,
    )
    parser.add_argument(
        "-gt", "--ground_truth_path", dest="ground_truth_path",
        help="Path to ground truth annotation",
        default=constants.CLASSIFICATION_GROUND_TRUTH_PATH,
    )
    parser.add_argument(
        "-clf", "--classifier", dest="classifier",
        help="Classifier : Baseline, Custom (default), BOW, Deep",
        default="Custom",
    )
    parser.add_argument("--one_query", action="store_true", help="Predict only one query")
    args = parser.parse_args()

    def _list_folder(folder):
        # Escape the folder so glob metacharacters in the path itself
        # (e.g. "[" or "*") are matched literally, and sort because glob's
        # order depends on the filesystem and would otherwise vary per run.
        return sorted(glob(os.path.join(glob_escape(folder), "*")))

    args.catalog_images_paths = _list_folder(args.catalog_images_folder)
    if not args.one_query:
        args.query_images_paths = _list_folder(args.query_path)
    return args
def get_classifier(args):
    """Instantiate the classifier selected by ``args.classifier``.

    Recognised names (case-sensitive): "Baseline", "Custom", "Deep", "BOW".
    Any other value still falls back to ``CustomClassifier``. A warning is
    written to stderr so that a typo (e.g. "bow") is not silently replaced
    by the default.

    Args:
        args: namespace providing ``classifier`` (the selected name) and
            ``catalog_images_paths`` (passed to the classifier constructor),
            as built by ``parse_args``.

    Returns:
        An instance of the selected classifier class, constructed with
        ``args.catalog_images_paths``.
    """
    import sys

    classifier_types = {
        "Baseline": BaselineClassifier,
        "Custom": CustomClassifier,
        "Deep": DeepClassifier,
        "BOW": BOWClassifier,
    }
    classifier_type = classifier_types.get(args.classifier)
    if classifier_type is None:
        sys.stderr.write(
            "Warning: unknown classifier {!r}, falling back to Custom\n".format(args.classifier)
        )
        classifier_type = CustomClassifier
    # Only the selected class is constructed; the others are never instantiated.
    return classifier_type(args.catalog_images_paths)
########
# Main #
########
if __name__ == "__main__":
    # Make sure ./files exists before anything else runs (presumably used by
    # the classifiers for generated files — TODO confirm). exist_ok avoids
    # the check-then-create race of a separate exists()/makedirs() pair.
    os.makedirs("./files", exist_ok=True)
    args = parse_args()
    classifier = get_classifier(args)
    if args.one_query:
        # Assumes predict_query returns a label -> score mapping where a
        # higher score means a better match (hence the descending sort).
        scores = classifier.predict_query(args.query_path)
        top5labels = sorted(scores.keys(), key=lambda label: scores[label], reverse=True)[:5]
        print("Top 5 labels :")
        # Iterate over the labels actually present, so fewer than 5 labels
        # no longer raises IndexError.
        for rank, label in enumerate(top5labels, start=1):
            print("{}. Label = {} | Score = {}".format(rank, label, scores[label]))
    else:
        # compute_metrics returns three values reported as percentages below
        # (values presumably in [0, 1] — TODO confirm).
        top1, top3, top5 = classifier.compute_metrics(args.query_images_paths, args.ground_truth_path)
        print("Metrics : Top1 = {:.3f}% | Top3 = {:.3f}% | Top5 = {:.3f}%".format(top1 * 100, top3 * 100, top5 * 100))