"""
Module for doing inference using
trained classifiers.
"""
import numpy as np
from sklearn.metrics import precision_recall_fscore_support
from joblib import load

from classifier_models import get_classifier
from data_process import pickle_read
from data_generator import load_all


# TODO: modify this detector to accommodate SVM and SVC with k-means
class AttackDetector:
    """
    Detects attacks using an ensemble of classifiers.
    """
    def __init__(self):
        self.classifiers = ["tiny_2_layer", "2_layer_dense"]  # neural classifiers to use
        self.ml_model_names = ["svm_liner", "gmm_spherical", "logistic_regression"]
        self.weights = [1, 1, 1, 0, 1]  # weight of each classifier; 0 disables it
        assert len(self.classifiers + self.ml_model_names) == len(self.weights)
        self.weights = np.array(self.weights)
        self.models = []
        self.init_models()

    def init_models(self):
        """Load the pretrained neural and classical models."""
        for name in self.classifiers:
            classifier = get_classifier(name=name, load_pretrained=True)
            self.models.append(classifier)
        for name in self.ml_model_names:
            classifier = self.get_ml_model(name)
            self.models.append(classifier)

    def get_ml_model(self, name):
        path_save_model = "./data/models/{}_classifier/{}.joblib"
        return load(path_save_model.format(name.split("_")[0], name))

    def is_attack(self, embedding, threshold=0.9, verbose=False):
        embedding = np.expand_dims(embedding, axis=0)
        preds = []
        for model in self.models:
            pred = model.predict(embedding)
            # Predictions can be scalars, lists or arrays depending on the
            # backing model; reduce each to a single probability.
            pred = np.asarray(pred).flatten()[0]
            preds.append(pred)
        # Mean of the weighted probabilities over models with non-zero weight.
        weighted_preds = np.array(preds) * self.weights
        mask = self.weights != 0
        mean_pred = np.sum(weighted_preds) / np.sum(mask)
        if verbose:
            print("Classifiers:", self.classifiers + self.ml_model_names)
            print("Attack Probabilities:", preds)
            print("Weights:", self.weights)
            print("Mean Probability:", mean_pred)
            print("Is attack:", mean_pred > threshold)
        if mean_pred > threshold:
            return 1  # attack confirmed
        return 0  # not an attack

    def evaluate(self, threshold):
        """
        Returns precision, recall, F-score and support of
        the attack detector on the test data.
        """
        test_set = pickle_read("./data/print_attack/processed/test.pkl")
        x_test, y_test = load_all(test_set)
        preds = []
        for x in x_test:
            pred = self.is_attack(x, threshold=threshold)
            preds.append(pred)
        preds = np.array(preds)
        return precision_recall_fscore_support(y_true=y_test, y_pred=preds, average="binary")
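

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). It assumes the
    # pretrained checkpoints under ./data/models/ and the processed test set
    # under ./data/print_attack/processed/ already exist, and that load_all()
    # yields embeddings of the shape the classifiers were trained on.
    detector = AttackDetector()

    # Score one embedding from the test pickle (a stand-in for a live sample).
    x_test, _ = load_all(pickle_read("./data/print_attack/processed/test.pkl"))
    print("Attack?", detector.is_attack(x_test[0], threshold=0.9, verbose=True))

    # Evaluate the whole ensemble on the held-out test split.
    precision, recall, fscore, _ = detector.evaluate(threshold=0.9)
    print(f"precision={precision:.3f}, recall={recall:.3f}, f-score={fscore:.3f}")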