-
Notifications
You must be signed in to change notification settings - Fork 1
/
apply_umap.py
64 lines (57 loc) · 1.8 KB
/
apply_umap.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import umap
import time
import pickle
import sys
import json
import os
from sklearn.neighbors import KDTree
# Neighbour-search strategies accepted via the ``neighbor_approach`` option.
_NEIGHBOR_APPROACHES = ('radius', 'knn')


def _load_pickle(file_path):
    """Unpickle and return the object stored at *file_path*, closing the file."""
    # NOTE(review): pickle.load can execute arbitrary code; this assumes the
    # files under data/ are produced locally by earlier pipeline steps.
    with open(file_path, 'rb') as fh:
        return pickle.load(fh)


def _dump_pickle(obj, file_path):
    """Pickle *obj* to *file_path*, closing (and therefore flushing) the file."""
    with open(file_path, 'wb') as fh:
        pickle.dump(obj, fh)


def _drop_self(neighbor_lists):
    """Return *neighbor_lists* with each row's own index removed.

    Row ``i`` holds the neighbour indices found for point ``i``. Because the
    tree is queried with the same points it was built from, a point can be
    reported as its own neighbour; every occurrence of ``i`` is filtered out
    of row ``i``.
    """
    return [[j for j in row if j != i] for i, row in enumerate(neighbor_lists)]


def run(**kwargs):
    """Fit UMAP on the preprocessed data, compute neighbours, and pickle results.

    Reads ``data/<model_tag>_preprocess_df.pkl`` and writes
    ``data/<model_tag>_umap.pkl`` (fitted mapper),
    ``data/<model_tag>_embeddings.pkl`` (``mapper.embedding_``) and
    ``data/<model_tag>_knn_indices.pkl`` (per-point neighbour index lists,
    excluding the point itself).

    Keyword args read:
        model_tag: prefix for the input/output file names.
        components, n_neighbors, min_dist, metric, densmap: passed to
            ``umap.UMAP``; ``n_neighbors`` is also ``k`` for the knn search.
        neighbor_approach: ``'radius'`` (KDTree radius query) or ``'knn'``
            (KDTree k-nearest query).
        threshold: search radius; only read when ``neighbor_approach`` is
            ``'radius'``.

    Exits the process with a non-zero status (via ``SystemExit``) if
    ``neighbor_approach`` is not one of the supported values.
    """
    approach = kwargs['neighbor_approach']
    # Validate before loading data or fitting UMAP so a config typo fails fast.
    if approach not in _NEIGHBOR_APPROACHES:
        # Passing a message makes SystemExit print it to stderr and exit with
        # status 1; a bare sys.exit() would report success (status 0).
        sys.exit(
            f'Invalid neighbor_approach {approach!r}. '
            'Please use either radius or knn.'
        )

    t0 = time.time()
    path = os.path.join('data', f'{kwargs["model_tag"]}_')
    preprocess_df = _load_pickle(path + 'preprocess_df.pkl')
    umap_mapper = umap.UMAP(
        n_components=kwargs['components'],
        n_neighbors=kwargs['n_neighbors'],
        min_dist=kwargs['min_dist'],
        metric=kwargs['metric'],  # ideally hellinger, but takes ages
        densmap=kwargs['densmap'],
        verbose=True,
    )
    mapper = umap_mapper.fit(preprocess_df)
    print('UMAP run in', f'{round(time.time() - t0, 2)}s')

    t0 = time.time()
    embedding = mapper.embedding_
    tree = KDTree(embedding)
    if approach == 'radius':
        # sort_results=True requires return_distance=True; distances unused.
        indices, _ = tree.query_radius(
            embedding,
            r=kwargs['threshold'],
            return_distance=True,
            sort_results=True,
        )
    else:  # 'knn'
        indices = tree.query(
            embedding,
            k=kwargs['n_neighbors'],
            return_distance=False,
        )
    knn_indices = _drop_self([row.tolist() for row in indices])
    print('Nearest Neighbors run in', f'{round(time.time() - t0, 2)}s')

    _dump_pickle(mapper, path + 'umap.pkl')
    _dump_pickle(embedding, path + 'embeddings.pkl')
    _dump_pickle(knn_indices, path + 'knn_indices.pkl')
if __name__ == '__main__':
    # Prefer a developer-local config and fall back to the shared one only
    # when the local file does not exist; a malformed local file now raises
    # instead of being silently ignored.
    try:
        with open('args.local.json', encoding='utf-8') as fh:
            kwargs = json.load(fh)
    except FileNotFoundError:
        with open('args.json', encoding='utf-8') as fh:
            kwargs = json.load(fh)
    # Dict union: entries in 'global' override same-named 'apply_umap' entries.
    run(**(kwargs['apply_umap'] | kwargs['global']))