-
Notifications
You must be signed in to change notification settings - Fork 0
/
log_regression.py
33 lines (30 loc) · 1.35 KB
/
log_regression.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from data_process import pickle_read
from data_generator import DataGenerator, load_all
from joblib import dump, load
import numpy as np
# joblib usage: dump(clf, 'filename.joblib') saves a fitted model to disk;
# clf = load('filename.joblib') restores it.
from sklearn import svm
from sklearn import metrics
from sklearn.linear_model import LogisticRegression
if __name__ == "__main__":
    # Script entry point: fit a logistic-regression classifier on the
    # processed print_attack splits, print train/test accuracy, save the
    # fitted model with joblib, then reload it and re-check test accuracy.
    import os  # local import: only needed to prepare the output directory

    train_set = pickle_read("./data/print_attack/processed/train.pkl")
    valid_set = pickle_read("./data/print_attack/processed/valid.pkl")
    test_set = pickle_read("./data/print_attack/processed/test.pkl")

    name = "logistic_regression"
    path_save_model = "./data/models/{}_classifier/{}.joblib"
    # Build the destination once: the directory is keyed by the first word
    # of the name ("logistic"), the file by the full name.
    model_path = path_save_model.format(name.split("_")[0], name)

    x_train, y_train = load_all(train_set)
    # NOTE(review): the validation split is loaded but not used anywhere
    # below; kept to preserve the original behavior -- either evaluate on
    # it or drop it.
    x_valid, y_valid = load_all(valid_set)
    x_test, y_test = load_all(test_set)

    clf = LogisticRegression(random_state=42)
    clf.fit(x_train, y_train)

    y_pred = clf.predict(x_train)
    print("Train accuracy:", metrics.accuracy_score(y_train, y_pred))
    y_pred = clf.predict(x_test)
    print("Accuracy:", metrics.accuracy_score(y_test, y_pred))
    # Single-sample prediction example (adds a leading axis to one row):
    # clf.predict(np.expand_dims(x_test[0, :], axis=0))

    print(model_path)
    # dump() does not create missing parent directories, so make sure the
    # target folder exists before writing the model file.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    dump(clf, model_path)

    # Round-trip check: predict with the *reloaded* model (the original
    # mistakenly reused the in-memory `clf`, so the saved file was never
    # actually validated).
    clf_load = load(model_path)
    y_pred = clf_load.predict(x_test)
    print("Accuracy:", metrics.accuracy_score(y_test, y_pred))