-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_classifier.py
51 lines (35 loc) · 1.53 KB
/
train_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import pickle
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
def main(
    data_path: str = "data/data.pickle",
    model_path: str = "data/model.pickle",
) -> None:
    """Train a random forest classifier on hand-landmark coordinates.

    Loads the pickled dataset at ``data_path`` (a dict with ``"data"`` and
    ``"labels"`` entries), flattens every sample into one feature vector,
    holds out a stratified 20% of the samples for evaluation, prints the
    accuracy on that hold-out set, and pickles the fitted model to
    ``model_path``.

    Args:
        data_path: Pickle file holding the dataset. Defaults to the path the
            script has always read from.
        model_path: File the fitted model is pickled to. Defaults to the path
            the script has always written to.

    Raises:
        ValueError: If the dataset is empty, its samples do not all share the
            same shape, or the number of labels differs from the number of
            samples.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only point this at a dataset you generated yourself.
    with open(data_path, "rb") as f:
        data_dict = pickle.load(f)

    # Layout documented by the original author: (n_samples, 21, 2), i.e.
    # 21 landmarks with an (x, y) pair each.
    try:
        data = np.array(data_dict["data"])
    except ValueError as err:
        # Recent NumPy versions refuse to build an array from samples of
        # differing length; surface that as a dataset problem.
        raise ValueError(
            f"samples in {data_path!r} do not all have the same shape"
        ) from err
    # One label per sample: (n_samples,)
    labels = np.array(data_dict["labels"])

    if data.size == 0:
        # Without this guard the reshape below fails with an opaque message.
        raise ValueError(f"dataset in {data_path!r} contains no samples")

    n_samples = data.shape[0]
    if len(labels) != n_samples:
        raise ValueError(
            f"dataset in {data_path!r} has {n_samples} samples "
            f"but {len(labels)} labels"
        )

    # Flatten each sample: (n_samples, 21, 2) -> (n_samples, 42)
    data = np.reshape(data, (n_samples, -1))

    # Shuffled 80%/20% split that keeps the class proportions in both parts.
    # NOTE: no random_state is passed here or to the model below, so the
    # split, the forest and the printed score vary from run to run.
    train_data, test_data, train_labels, test_labels = train_test_split(
        data, labels, test_size=0.2, stratify=labels, shuffle=True
    )

    # 100 trees with depth capped at 7 (max_depth=7 is a deliberate limit,
    # not scikit-learn's default of unlimited depth).
    model = RandomForestClassifier(n_estimators=100, max_depth=7)
    model.fit(train_data, train_labels)

    # Check how the model performs on unseen data.
    pred_labels = model.predict(test_data)
    score = accuracy_score(test_labels, pred_labels)
    # The original author reported 94.87% on their dataset.
    print(f"Correctly classified {score*100:.2f}% of samples")

    with open(model_path, "wb") as f:
        pickle.dump(model, f)
# Script entry point: train and save the classifier only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()