-
Notifications
You must be signed in to change notification settings - Fork 0
/
cli.py
150 lines (115 loc) · 3.88 KB
/
cli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
from typing import List
import numpy as np
from search_clustering.client import ElasticClient
from search_clustering.presets import get_demo_preset
# Search client created at import time; `query` calls `es.search` on it.
es = ElasticClient("http://localhost:9200")
# Pipelines from the demo preset: `cluster_knn` uses pipe_knn / pipe_none,
# `cluster_temp` uses pipe_temp.
pipe_knn, pipe_temp, pipe_none = get_demo_preset()
# Help text printed verbatim by the `help` command.
commands: str = """
query [-t] <q>: query and cluster documents for query q, -t for temporal clustering
cluster [-t] <c>: hierarchically cluster the cluster with index c, -t for temporal clustering
list <c>: list the documents in the cluster with index c
show <d>: print the formatted document with ID d
detail <d>: print the entire document with ID d including all metadata
quit: self-explanatory
"""
# Mutable session state shared by all commands:
#   "D_q"      - documents returned by the last clustering run (None until a query ran)
#   "clusters" - per-document cluster ids from the last clustering run
#   "labels"   - cluster labels printed by `print_clusters`
#   "query"    - the last query string passed to `query`
cache: dict = {"D_q": None, "clusters": None, "labels": None, "query": None}
def filter_cache(c: str):
    """Return the cached documents assigned to the cluster with index ``c``.

    The index one past the highest cached cluster id is mapped to cluster
    id -1 (presumably the noise/outlier cluster shown last in the label
    list -- confirm against the pipelines).

    Args:
        c: Cluster index as typed by the user.

    Returns:
        The list of matching documents from ``cache["D_q"]``. An empty list
        is returned (after printing a message) when ``c`` is not an integer
        or when no query has populated the cache yet.
    """
    try:
        idx = int(c)
    except ValueError:
        print("Invalid cluster index")
        return []

    if cache["D_q"] is None:
        # Nothing has been queried yet; bail out instead of crashing on
        # max(None) further down.
        print("Query documents first")
        return []

    # asarray makes the element-wise comparison work even if the pipeline
    # handed back a plain list of cluster ids.
    clusters = np.asarray(cache["clusters"])
    if clusters.size and idx == clusters.max() + 1:
        idx = -1
    return [cache["D_q"][i] for i in np.where(clusters == idx)[0]]
def list_cluster(D_c: List[dict]) -> None:
    """Print one line per document in ``D_c``: its ID followed by its title."""
    for document in D_c:
        doc_id, title = document["_id"], document["_source"]["title"]
        print(doc_id, title)
def print_clusters() -> None:
    """Print the cached cluster labels, each prefixed with its list index."""
    print("Clusters:")
    labels = cache["labels"]
    for numbered_label in enumerate(labels):
        # numbered_label is an (index, label) pair
        print(*numbered_label)
def show(_id: str) -> None:
    """Print title, URL and body of the cached document whose ID is ``_id``.

    Prints "Query documents first" when no query has populated the cache
    yet, and "Invalid ID" when no cached document carries that ID.
    """
    docs = cache["D_q"]
    if docs is None:
        # Iterating None would raise TypeError and kill the prompt loop.
        print("Query documents first")
        return

    for doc in docs:
        if doc["_id"] == _id:
            print(doc["_source"]["title"])
            print(doc["_source"]["url"], end="\n\n")
            print(doc["_source"]["body"])
            return
    print("Invalid ID")
def detail(_id: str) -> None:
    """Print the complete cached document (all fields) whose ID is ``_id``.

    Prints "Query documents first" when no query has populated the cache
    yet, and "Invalid ID" when no cached document carries that ID.
    """
    docs = cache["D_q"]
    if docs is None:
        # Iterating None would raise TypeError and kill the prompt loop.
        print("Query documents first")
        return

    for doc in docs:
        if doc["_id"] == _id:
            print(doc)
            return
    print("Invalid ID")
def cluster_knn(D: List[dict]) -> None:
    """Cluster ``D`` with the non-temporal pipelines and cache the outcome.

    Eight or more documents go through ``pipe_knn`` (verbose); one to seven
    go through ``pipe_none`` (quiet); an empty input is a no-op. On success
    the resulting documents, cluster ids and labels replace the cached ones
    and the labels are printed.
    """
    size = len(D)
    if size == 0:
        return

    if size >= 8:
        pipeline, verbose = pipe_knn, True
    else:
        pipeline, verbose = pipe_none, False

    # The fourth element of the pipeline result is not used here.
    docs, clusters, labels, _ = pipeline.fit_transform(
        D, verbose=verbose, visualize=False, query=cache["query"]
    )
    cache.update(D_q=docs, clusters=clusters, labels=labels)
    print_clusters()
def cluster_temp(D) -> None:
    """Cluster ``D`` with the temporal pipeline and cache the outcome.

    An empty input is a no-op, matching ``cluster_knn``. Otherwise the
    resulting documents, cluster ids and labels replace the cached ones and
    the labels are printed.
    """
    if len(D) == 0:
        # Nothing to cluster; keep the current cache untouched.
        return

    # NOTE(review): three values are unpacked here while cluster_knn unpacks
    # four from its pipelines -- confirm pipe_temp really returns a 3-tuple.
    docs, clusters, labels = pipe_temp.fit_transform(D, verbose=True, visualize=False)
    cache["D_q"] = docs
    cache["clusters"] = clusters
    cache["labels"] = labels
    print_clusters()
def query(q, temporal=False) -> None:
    """Search for ``q``, keep suitable hits and cluster them.

    The query string is stored in the cache, up to 10,000 hits are requested
    from ``es``, hits whose body contains no newline are discarded and at
    most the first 500 remaining hits are clustered -- temporally when
    ``temporal`` is true, otherwise via ``cluster_knn``. Nothing is
    clustered when no hit remains.
    """
    cache["query"] = q
    hits = es.search(q, size=10_000)

    multiline_hits = [hit for hit in hits if "\n" in hit["_source"]["body"]]
    D_q = multiline_hits[:500]
    print("Retrieved", len(D_q), "results")
    if not D_q:
        return

    run_clustering = cluster_temp if temporal else cluster_knn
    run_clustering(D_q)
def parse_input(inp_str: str = ""):
    """Parse one line typed at the prompt and run the matching command.

    The first whitespace-separated token selects the command; the remaining
    tokens are its arguments. ``query`` and ``cluster`` accept an optional
    leading ``-t`` flag that selects temporal clustering. Blank input is
    ignored, and a command called with the wrong number of arguments prints
    a usage hint instead of failing or doing nothing.

    Args:
        inp_str: The raw input line.

    Raises:
        SystemExit: for the ``quit`` command.
    """
    # split() without a separator tolerates leading, trailing and repeated
    # whitespace, so "cluster  3" still yields exactly one argument.
    inp = inp_str.split()
    if not inp:
        return
    func = inp.pop(0)

    if func == "quit":
        # exit() is injected by the site module for interactive use only;
        # raising SystemExit works in every interpreter configuration.
        raise SystemExit
    elif func == "help":
        print(commands)
    elif func == "query":
        temporal = bool(inp) and inp[0] == "-t"
        terms = inp[1:] if temporal else inp
        if not terms:
            print("Usage: query [-t] <q>")
            return
        query(" ".join(terms), temporal)
    elif func == "cluster":
        temporal = bool(inp) and inp[0] == "-t"
        args = inp[1:] if temporal else inp
        if len(args) != 1:
            print("Usage: cluster [-t] <c>")
            return
        # The cluster index is the argument after the optional flag.
        D_c = filter_cache(args[0])
        if temporal:
            cluster_temp(D_c)
        else:
            cluster_knn(D_c)
    elif func == "list":
        if len(inp) != 1:
            print("Usage: list <c>")
            return
        list_cluster(filter_cache(inp[0]))
    elif func == "show":
        if len(inp) != 1:
            print("Usage: show <d>")
            return
        show(inp[0])
    elif func == "detail":
        if len(inp) != 1:
            print("Usage: detail <d>")
            return
        detail(inp[0])
    elif func == "cache":
        # Undocumented debugging aid: dump the raw session state.
        print(cache)
    else:
        print("Invalid command (Type 'help' to see all available commands)")
if __name__ == "__main__":
    # Interactive prompt loop: read a line, dispatch it, repeat until `quit`.
    # NOTE(review): the "Connected" banner is printed unconditionally; no
    # connectivity check is visible in this file.
    print("Connected to http://localhost:9200")
    print("Type 'help' to see all available commands")
    while True:
        try:
            line = input("> ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt: leave quietly, no traceback.
            print()
            break
        parse_input(line)