forked from Grouper/Cuttlefish
-
Notifications
You must be signed in to change notification settings - Fork 0
/
rf_model.py
62 lines (52 loc) · 1.88 KB
/
rf_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from sklearn.ensemble import RandomForestClassifier
import csv
import random
def clean_data(data, n_pairs=10):
    """Convert raw CSV feature rows to floats and append pairwise differences.

    The first row of ``data`` is treated as a header and dropped. Every
    remaining row is converted to ``float`` and then extended with
    ``row[i] - row[i + n_pairs]`` for each ``i`` in ``range(n_pairs)``. That is,
    each of the first ``n_pairs`` columns is differenced against the column
    ``n_pairs`` positions later.

    Args:
        data: list of rows (sequences of numeric strings), header row first.
        n_pairs: number of leading columns to difference against the next
            ``n_pairs`` columns. Defaults to 10, the value previously
            hard-coded.

    Returns:
        A new list of float rows, one per non-header input row.

    Raises:
        ValueError: if a cell cannot be parsed as a float.
        IndexError: if a row has fewer than ``2 * n_pairs`` columns.
    """
    cleaned = []
    for row in data[1:]:
        values = [float(v) for v in row]  # parse each cell once, reuse for diffs
        diffs = [values[i] - values[i + n_pairs] for i in range(n_pairs)]
        cleaned.append(values + diffs)
    return cleaned
# Load the labelled training set. Each CSV row contributes columns 1-10 and
# 12..(second-to-last) as features, and its last column as the label. Columns 0
# and 11 are not used as features (presumably identifiers; confirm against the
# file header).
# NOTE: open mode 'U' was removed in Python 3.11. Under Python 3, use mode 'r'
# with newline='' instead.
data = []
result = []
with open('training_data.csv', 'rU') as csvfile:
    rowreader = csv.reader(csvfile, delimiter=',')
    for row in rowreader:
        # csv.reader yields [] for a blank line, and row[-1] would then raise
        # IndexError, so skip such lines.
        if not row:
            continue
        data.append(row[1:11] + row[12:-1])
        result.append(row[-1])
# Row 0 is the header. It is dropped from the labels here, and clean_data()
# drops it from the features. Any label other than the exact string 'FALSE'
# is treated as class 1.
result = [0 if v == 'FALSE' else 1 for v in result[1:]]
data = clean_data(data)
# Crude random hyper-parameter search. Each round holds out a random ~3% of
# the training rows, fits a random forest with randomly drawn settings on the
# remaining rows, and keeps whichever forest scores best on its own hold-out
# rows. The author notes that this data "seems to defy most methods". Because
# each hold-out set is tiny, the winning score is noisy.
best_score = 0.0
best_model = None
# 500 rounds takes a long time; lower this to make the search faster.
for _ in range(500):
    tr_data = []
    tr_result = []
    cv_data = []
    cv_result = []
    for row, label in zip(data, result):
        if random.random() < 0.03:
            cv_data.append(row)
            cv_result.append(label)
        else:
            tr_data.append(row)
            tr_result.append(label)
    # A ~3% draw can leave the hold-out set empty on small data. The accuracy
    # below would then divide by zero, so skip the round.
    if not cv_data:
        continue
    clf = RandomForestClassifier(
        n_estimators=random.randint(10, 30),
        criterion="entropy" if random.random() < 0.5 else "gini",
        # For classifiers, "sqrt" is what "auto" meant. "auto" was removed in
        # scikit-learn 1.3.
        max_features="sqrt" if random.random() < 0.5 else None,
    )
    clf.fit(tr_data, tr_result)
    # predict() expects a 2-D (n_samples, n_features) input, so score the whole
    # hold-out set in one call rather than passing 1-D rows one at a time.
    predictions = clf.predict(cv_data)
    match_count = sum(1 for p, r in zip(predictions, cv_result) if p == r)
    score = float(match_count) / len(cv_data)
    # Always accept the first fitted model, so best_model is set even if every
    # hold-out score is 0.0.
    if best_model is None or score > best_score:
        best_score = score
        best_model = clf
# Score the test set with the selected forest and write it to stdout as CSV:
# each original test row followed by a members_became_friends prediction.
# NOTE: open mode 'U' (kept to match the training loader) was removed in
# Python 3.11. Under Python 3, use mode 'r' with newline='' instead.
with open('test_data.csv', 'rU') as csvfile:
    test = list(csv.reader(csvfile, delimiter=','))
header = test[0]
# Same column selection as the training rows; clean_data() drops the header.
td = clean_data([row[1:11] + row[12:-1] for row in test])
# A single-argument print(...) behaves the same under Python 2 and 3.
print(','.join(header) + ',members_became_friends')
if td:
    if best_model is None:
        raise SystemExit('no model was selected; cannot score test_data.csv')
    # predict() expects a 2-D (n_samples, n_features) input, so score every
    # test row in one call.
    predictions = best_model.predict(td)
    for raw_row, label in zip(test[1:], predictions):
        # NOTE(review): unlike the header line, no ',' is inserted before the
        # label. The output is only well-formed if each data row ends with an
        # empty trailing field (e.g. a blank label column). Confirm against
        # test_data.csv before changing this.
        print(','.join(raw_row) + ('FALSE' if label == 0 else 'TRUE'))