-
Notifications
You must be signed in to change notification settings - Fork 0
/
cross_validation.py
39 lines (33 loc) · 1.27 KB
/
cross_validation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import numpy as np
from nearest_neighbors import KNNClassifier
def kfold(n, n_folds):
    """Split the indices ``0 .. n-1`` into contiguous train/validation folds.

    The indices are partitioned with ``np.array_split``, so the folds are
    contiguous, in increasing order, and differ in size by at most one
    element when ``n_folds`` is an integer.

    Args:
        n: Number of samples; indices ``0 .. n-1`` are generated.
        n_folds: Passed straight to ``np.array_split`` as the number of
            sections (or the split points).

    Returns:
        A list with one ``(train_indices, validation_indices)`` tuple per
        fold. ``validation_indices`` is the fold itself and
        ``train_indices`` contains every other index, in increasing order.

    Raises:
        ValueError: If any fold would be empty (for example when
            ``n_folds`` exceeds ``n``), or if ``np.array_split`` rejects
            ``n_folds``.
    """
    indices = np.arange(0, n)
    folds = np.array_split(indices, n_folds)

    # An empty fold has no first/last element to slice around and would
    # give an empty validation set, so fail early with a clear message.
    if any(fold.size == 0 for fold in folds):
        raise ValueError(
            'cannot split {} samples into folds using n_folds={!r}: '
            'at least one fold would be empty'.format(n, n_folds)
        )

    splits = []
    for fold in folds:
        # Folds are contiguous runs of ``indices`` and indices equal their
        # own positions, so the training set is everything before the
        # fold's first element plus everything after its last one.
        start, stop = fold[0], fold[-1] + 1
        train = np.concatenate([indices[:start], indices[stop:]])
        splits.append((train, fold))
    return splits
def knn_cross_val_score(X, y, k_list, score, cv, **kwargs):
    """Cross-validate ``KNNClassifier`` for several neighbour counts.

    For every ``k`` in ``k_list`` a ``KNNClassifier(k=k, **kwargs)`` is
    created, fitted on the training part of each fold and scored on the
    validation part.

    Args:
        X: Feature array; indexed with the fold index arrays, and
            ``X.shape[0]`` is used as the sample count when ``cv`` is None.
        y: Target array; indexed with the same fold index arrays.
        k_list: Iterable of ``k`` values to evaluate.
        score: Name of the metric. Only ``'accuracy'`` is supported.
        cv: Iterable of ``(train_indices, validation_indices)`` pairs, or
            None to use a 3-fold split produced by ``kfold``.
        **kwargs: Extra keyword arguments forwarded to ``KNNClassifier``.
            A ``k`` entry is ignored because ``k`` comes from ``k_list``.

    Returns:
        A dict mapping each ``k`` to a NumPy array holding one accuracy
        value per fold, or None when ``score`` is not ``'accuracy'``.
    """
    if score != 'accuracy':
        return None

    # ``k`` is supplied explicitly below; always drop a caller-provided one
    # so the constructor never receives the keyword twice.
    kwargs.pop('k', None)

    if cv is None:
        cv = kfold(X.shape[0], 3)
    # Materialise the folds once so a one-shot iterator is not exhausted
    # after the first value of ``k``.
    cv = list(cv)

    scores_by_k = {}
    for k in k_list:
        classifier = KNNClassifier(k=k, **kwargs)
        fold_scores = []
        for train_idx, valid_idx in cv:
            classifier.fit(X[train_idx], y[train_idx])
            predicted = classifier.predict(X[valid_idx])
            fold_scores.append(accuracy(predicted, y[valid_idx]))
        scores_by_k[k] = np.array(fold_scores)
    return scores_by_k
def accuracy(a, b):
    """Return the fraction of positions where ``a`` and ``b`` agree.

    Both arguments are converted with ``np.asarray``, so plain sequences
    such as lists are accepted as well as NumPy arrays.

    Args:
        a: Predicted labels.
        b: Reference labels, compared element-wise with ``a``.

    Returns:
        The number of matching elements divided by ``a.shape[0]``, as a
        float.

    Raises:
        ZeroDivisionError: If ``a`` is empty.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    # count_nonzero on the boolean mask gives the number of matches
    # without building an intermediate filtered copy of ``a``.
    return np.count_nonzero(a == b) / a.shape[0]