-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_county.py
57 lines (47 loc) · 1.65 KB
/
train_county.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import argparse
import pickle
import csv
from sklearn import svm
def parse_features(string_feature):
    """Parse a bracketed, comma-separated feature string into a list of floats.

    Args:
        string_feature: Text such as ``'[0, 0, 0, 1]'``. After surrounding
            whitespace is stripped, the first and last characters are
            treated as the enclosing brackets and dropped.

    Returns:
        A list of floats, one per comma-separated entry, or ``[]`` when
        nothing but whitespace lies between the brackets.

    Raises:
        ValueError: If any entry cannot be converted with ``float``.
    """
    # Strip first so a trailing newline/space is not mistaken for a bracket.
    inner = string_feature.strip()[1:-1]
    if not inner.strip():
        # ''.split(',') would yield [''] and float('') raises; an empty
        # feature list is a legitimate input.
        return []
    # Split on the bare comma: float() already ignores surrounding spaces,
    # so both '1, 2' and '1,2' are accepted.
    return [float(item) for item in inner.split(',')]
def run(args):
    """Train a linear SVM on per-county features and pickle the fitted model.

    The feature CSV at ``args.file_path`` is read for its ``County``,
    ``State`` and ``Features`` columns; the label CSV at
    ``args.labels_path`` is read for its ``county``, ``state`` and
    ``turnout rate`` columns. Rows are joined on the (county, state) pair;
    feature rows with no matching label are reported on stdout and skipped.
    The fitted ``svm.SVC(kernel='linear')`` is pickled to
    ``args.model_path``.

    Args:
        args: Object exposing ``file_path``, ``labels_path`` and
            ``model_path`` attributes (the parsed command-line arguments).
    """
    data = {}  # (county, state) -> parsed feature vector
    with open(args.file_path, mode='r') as csvfile:
        for row in csv.DictReader(csvfile):
            data[(row['County'], row['State'])] = parse_features(row['Features'])

    # Labels stay as the raw strings read from the CSV.
    # NOTE(review): SVC is a classifier, so each distinct turnout-rate string
    # presumably becomes its own class -- confirm this is the intent.
    labels = {}  # (county, state) -> turnout rate
    with open(args.labels_path, mode='r') as csvfile:
        for row in csv.DictReader(csvfile):
            labels[(row['county'], row['state'])] = row['turnout rate']

    trainX = []
    trainY = []
    for key, features in data.items():
        if key not in labels:
            print(str(key) + " not found")
            continue
        trainX.append(features)
        trainY.append(labels[key])
    print("Done constructing training data")

    clf = svm.SVC(kernel='linear')
    clf.fit(trainX, trainY)  # training
    print("Done training model")

    # Context manager guarantees the model file is flushed and closed.
    with open(args.model_path, 'wb') as model_file:
        pickle.dump(clf, model_file)
    print("Saved model successfully")
if __name__ == "__main__":
    # Pass the text as ``description``: the first positional parameter of
    # ArgumentParser is ``prog``, which would replace the program name in
    # the usage line instead of describing the script.
    parser = argparse.ArgumentParser(description='Train model on data')
    # CSV of per-county feature vectors.
    parser.add_argument('--file_path', type=str, default='social_data/2016_gov_agencies_data_county.csv')
    # Destination for the pickled model.
    parser.add_argument('--model_path', type=str, default='models/2016_gov_agencies_model')
    # CSV of per-county turnout labels.
    parser.add_argument('--labels_path', type=str, default='voter_data/2016_voter_data.csv')
    args = parser.parse_args()
    run(args)