-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
56 lines (45 loc) · 1.56 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""
=======================================================================
This script is for training PHIEmbed. It takes a CSV file corresponding
to the training dataset as input and outputs a trained scikit-learn
random forest classifier (serialized in joblib format).
@author Mark Edward M. Gonzales
=======================================================================
"""
import argparse
import joblib
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
# Number of embedding values stored per protein in the training CSV.
N_FEATURES = 1024
# Embedding feature columns are named "1" .. "1024", in file order.
FEATURE_COLUMNS = [str(i) for i in range(1, N_FEATURES + 1)]
# Column holding the class label (the host) that the classifier learns to predict.
LABEL_COLUMN = "Host"
# Full layout of the header-less CSV: protein ID, host label, then the embedding.
ALL_COLUMNS = ["Protein ID", LABEL_COLUMN] + FEATURE_COLUMNS
# Where the trained model is written when --output is not given.
DEFAULT_OUTPUT = "phiembed_trained.joblib.gz"


def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Argument strings to parse; ``None`` (the default) reads
            ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``input``, ``threads``, ``output``
        and ``seed`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input",
        required=True,
        help="Path to the training dataset",
    )
    parser.add_argument(
        "--threads",
        help="Number of threads to be used for training (default: -1, that is, use all threads)",
        type=int,
        default=-1,
    )
    parser.add_argument(
        "--output",
        help=f"Path to which the trained model is saved (default: {DEFAULT_OUTPUT})",
        default=DEFAULT_OUTPUT,
    )
    parser.add_argument(
        "--seed",
        help="Random seed for reproducible training (default: not seeded)",
        type=int,
        default=None,
    )
    return parser.parse_args(argv)


def split_features_labels(train):
    """Validate a raw training table and split it into features and labels.

    Args:
        train: DataFrame read from the header-less training CSV. Its columns
            are positional: protein ID, host label, then the embedding
            values. It is not modified.

    Returns:
        ``(X, y)``, where ``X`` is an ``(n_samples, N_FEATURES)`` NumPy array
        of embedding values and ``y`` is a 1-D array of host labels.

    Raises:
        ValueError: If the number of columns differs from the expected layout,
            or if a host label or embedding value is missing (for example,
            because of a truncated row).
    """
    if train.shape[1] != len(ALL_COLUMNS):
        # Checked explicitly rather than with ``assert``, which is stripped
        # under ``python -O``. Reading with a fixed ``names=`` list would
        # hide a mismatch: pandas uses surplus leading columns as the index
        # and pads short rows with NaN.
        raise ValueError(
            f"Expected {len(ALL_COLUMNS)} columns (protein ID, host, "
            f"{N_FEATURES} embedding values) but found {train.shape[1]}"
        )
    # set_axis returns a relabelled copy, so the caller's frame is untouched.
    train = train.set_axis(ALL_COLUMNS, axis="columns")
    if train[[LABEL_COLUMN] + FEATURE_COLUMNS].isna().to_numpy().any():
        raise ValueError("Training data contains missing host labels or embedding values")
    return train[FEATURE_COLUMNS].to_numpy(), train[LABEL_COLUMN].to_numpy()


def load_training_data(path):
    """Read the header-less training CSV at *path* and return ``(X, y)``.

    Validation is delegated to ``split_features_labels``, which raises
    ``ValueError`` on a malformed table.
    """
    # Read without ``names=`` so that the file's real column count can be checked.
    return split_features_labels(pd.read_csv(path, header=None))


def main(argv=None):
    """Train the PHIEmbed random forest and serialize it with joblib.

    Args:
        argv: Command-line arguments; ``None`` reads ``sys.argv[1:]``.
    """
    args = parse_args(argv)
    X_train, y_train = load_training_data(args.input)
    clf = RandomForestClassifier(
        class_weight="balanced",
        max_features="sqrt",
        min_samples_leaf=1,
        min_samples_split=2,
        n_estimators=150,
        n_jobs=args.threads,
        random_state=args.seed,
        verbose=True,
    )
    clf.fit(X_train, y_train)
    # joblib picks the compressor from the file extension (gzip for ".gz").
    joblib.dump(clf, args.output, compress=True)


if __name__ == "__main__":
    main()