# index.py -- click CLI for indexing text documents into OpenSearch.
from typing import List
import click
import loguru
from click import Context
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_core.documents import Document
from opensearchpy import OpenSearch, RequestsHttpConnection
from pycomfort.config import configure_logger, LogLevel, LOG_LEVELS, load_environment_keys
from pycomfort.logging import timing
from typing import Optional
from FlagEmbedding import BGEM3FlagModel
from hybrid_search.novel_embeddings import BgeM3Embeddings
from hybrid_search.opensearch_hybrid_search import OpenSearchHybridSearch
import logging
import torch
from opensearchpy import OpenSearch, exceptions
from opensearchpy import OpenSearch, RequestsHttpConnection
@click.group(invoke_without_command=True)
@click.pass_context
def app(ctx: Context):
    """Command group for the indexing CLI.

    When no subcommand is given, the ``main`` indexing command is run with
    its default option values.
    """
    if ctx.invoked_subcommand is None:
        click.echo('Running the default command...')
        # Invoke through the context instead of calling ``main()`` directly:
        # a direct call re-enters click's standalone entry point, which
        # re-parses sys.argv and ends with sys.exit(). ``ctx.invoke`` fills in
        # the declared defaults of ``main`` and returns normally.
        ctx.invoke(main)
@timing("indexing")
def index_function(data_path: str, glob_pattern: str, embedding: str, url: str, user: str, password: str, pipeline_name: str, index_name: str, device: str, space: str, logger: Optional["loguru.Logger"] = None):
    """Load text files, embed them and build an OpenSearch hybrid-search index.

    Args:
        data_path: Directory that is scanned for documents.
        glob_pattern: Glob used to select files inside ``data_path``.
        embedding: Embedding model name. Names containing ``"bge-m3"`` use
            ``BgeM3Embeddings``, other names containing ``"bge"`` use
            ``HuggingFaceBgeEmbeddings``, everything else uses
            ``HuggingFaceEmbeddings``.
        url: OpenSearch URL handed to ``OpenSearchHybridSearch.create``.
        user: Login passed to OpenSearch.
        password: Password passed to OpenSearch (never written to the log).
        pipeline_name: Name of the hybrid-search pipeline.
        index_name: Name of the index to create.
        device: Torch device; ``"detect"`` picks ``"cuda"`` when available,
            otherwise ``"cpu"``.
        space: Space type forwarded as ``space_type``.
        logger: Logger to report progress to; falls back to the global
            ``loguru.logger`` when omitted.
    """
    # The parameter defaults to None, so make sure there is always a logger
    # to call instead of failing with AttributeError on the first message.
    if logger is None:
        logger = loguru.logger
    # The password is deliberately redacted: this message can end up in log files.
    logger.info(f"indexing from {data_path} using pattern: {glob_pattern} \n using {embedding} with URL {url} \n USER: {user} PASSWORD: <redacted> \n index_name {index_name} ")
    loader = DirectoryLoader(data_path, glob=glob_pattern, loader_cls=TextLoader)
    if device == "detect":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    docs: list[Document] = loader.load()
    for doc in docs:
        # page_id is the source file name up to its first dot.
        doc.metadata['page_id'] = doc.metadata['source'].split('/')[-1].split('.')[0]
    model_kwargs = {"device": device, "trust_remote_code": True}
    encode_kwargs = {"normalize_embeddings": True}
    # "bge-m3" must be tested first because the plain "bge" check matches it too.
    if "bge-m3" in embedding:
        embeddings = BgeM3Embeddings(use_fp16=True, device=device)
    elif "bge" in embedding:
        embeddings = HuggingFaceBgeEmbeddings(
            model_name=embedding,
            model_kwargs=model_kwargs,
            encode_kwargs=encode_kwargs
        )
    else:
        embeddings = HuggingFaceEmbeddings(
            model_name=embedding,
            model_kwargs=model_kwargs,
            encode_kwargs=encode_kwargs
        )
    logger.info("starting to indexing")
    docsearch: OpenSearchHybridSearch = OpenSearchHybridSearch.create(
        url, index_name, embeddings,
        login=user, password=password,
        pipeline_name=pipeline_name,
        documents=docs,
        space_type=space
    )
    # Make sure the hybrid-search pipeline exists on the cluster.
    if not docsearch.check_pipeline_exists():
        logger.info(f"hybrid search pipeline does not exist, creating it for {url}")
        docsearch.create_pipeline(url)
    logger.info(f"Finished indexing with index name {index_name} and embedding {embedding} of data-path {data_path}")
# Main CLI command
@app.command("main")
@click.option('--data-path', show_default=True, default='data/tacutopapers_test_rsids_10k/', help='Path to the data directory.')
@click.option('--glob-pattern', show_default=True, default='*.txt', help='Glob pattern for files.')
@click.option('--embedding', show_default=True, default='Alibaba-NLP/gte-large-en-v1.5', help='Type of embedding to use.')  # can also use BAAI/bge-en-icl, BAAI/bge-large-en-v1.5 or allenai/specter2_aug2023refresh
@click.option('--url', show_default=True, default='https://localhost:9200', help='URL for the pipeline.')
@click.option('--user', show_default=True, default='admin', help='Username for the pipeline.')
# The default password is intentionally not echoed in --help output.
@click.option('--password', show_default=False, default='Mind2@Mind', help='Password for the pipeline.')
@click.option('--pipeline-name', show_default=True, default='norm-pipeline', help='Name of the pipeline.')
# Dash-style spelling added for consistency with the other options; the
# original underscore spelling is still accepted.
@click.option('--index-name', '--index_name', 'index_name', show_default=True, default='index-gte-test_rsids_10k', help='Name of index')
@click.option('--device', show_default=True, default='detect', help='Device to use')
@click.option('--space', type=click.Choice(["cosinesimil", "l2", "innerproduct", "l1", "linf"], case_sensitive=False), default='l2', help='Space to use for OpenSearch')
@click.option('--log-level', '--log_level', 'log_level', type=click.Choice(LOG_LEVELS, case_sensitive=False), default=LogLevel.DEBUG.value, help="logging level")
def main(data_path: str, glob_pattern: str, embedding: str, url: str, user: str, password: str, pipeline_name: str, index_name: str, device: str, space: str, log_level: str):
    """Index the documents found in DATA_PATH into an OpenSearch index.

    Configures logging (including a log file under ./logs), loads environment
    keys, resolves the torch device and then calls ``index_function``.
    """
    logger = configure_logger(log_level)
    logger.add("./logs/hybrid_index_{time}.log")
    load_environment_keys(usecwd=True)
    # Resolve "detect" here so the concrete device is known before indexing starts.
    if device == "detect":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return index_function(data_path, glob_pattern, embedding, url, user, password, pipeline_name, index_name, device, space, logger)
@app.command("test_connection")
@click.option('--url', default='https://localhost:9200', help='URL of the OpenSearch cluster')
@click.option('--username', default='admin', help='Username for the OpenSearch cluster')
@click.option('--password', default='Mind2@Mind', help='Password for the OpenSearch cluster')
@click.option('--use-ssl', default=True, type=bool, help='Use SSL for connection')
@click.option('--ssl-show-warn', default=False, type=bool, help='Show SSL warnings')
def test_opensearch(url: str, username: str, password: str, use_ssl: bool, ssl_show_warn: bool):
    """ Connects to OpenSearch and adds a test index with test data. """
    # Turn on DEBUG output, both globally and for the opensearchpy logger.
    logging.basicConfig(level=logging.DEBUG)
    logging.getLogger('opensearchpy').setLevel(logging.DEBUG)

    probe_index = 'test_index'
    sample_doc = {"name": "Test Name", "description": "This is a test entry."}
    try:
        # Certificate and hostname verification are switched off for this client.
        os_client = OpenSearch(
            hosts=[url],
            http_auth=(username, password),
            use_ssl=use_ssl,
            verify_certs=False,
            ssl_assert_hostname=False,
            ssl_show_warn=ssl_show_warn,
            connection_class=RequestsHttpConnection,
            trust_env=True,
        )
        # HTTP 400 is ignored when creating the index.
        os_client.indices.create(index=probe_index, ignore=400)
        outcome = os_client.index(index=probe_index, body=sample_doc).get('result')
        stored = outcome in ('created', 'updated')
        click.echo("Test data added successfully!" if stored else "Failed to add test data.")
    except exceptions.OpenSearchException as e:
        click.echo(f"Error interacting with OpenSearch: {e}")
@app.command("delete_index")
@click.option('--url', default='https://localhost:9200', help='URL of the OpenSearch cluster')
@click.option('--username', default='admin', help='Username for the OpenSearch cluster')
@click.option('--password', default='Mind2@Mind', help='Password for the OpenSearch cluster')
@click.option('--use-ssl', default=True, type=bool, help='Use SSL for connection')
@click.option('--ssl-show-warn', default=False, type=bool, help='Show SSL warnings')
@click.option('--index-name', default='test_index', help='Name of the index to be deleted')
def delete_index(url: str, username: str, password: str, use_ssl: bool, ssl_show_warn: bool, index_name: str):
    """ Deletes a specified index from the OpenSearch cluster. """
    try:
        # Certificate and hostname verification are switched off for this client.
        os_client = OpenSearch(
            hosts=[url],
            http_auth=(username, password),
            use_ssl=use_ssl,
            verify_certs=False,
            ssl_assert_hostname=False,
            ssl_show_warn=ssl_show_warn,
            connection_class=RequestsHttpConnection,
            trust_env=True,
        )
        # HTTP 400 and 404 are ignored by the delete call.
        reply = os_client.indices.delete(index=index_name, ignore=[400, 404])
    except exceptions.OpenSearchException as e:
        click.echo(f"Error interacting with OpenSearch: {e}")
    else:
        # Only a reply carrying a truthy 'acknowledged' counts as a deletion.
        if reply.get('acknowledged', False):
            click.echo(f"Index '{index_name}' deleted successfully.")
        else:
            click.echo(f"Index '{index_name}' was not found or could not be deleted.")
@app.command("gte")
@click.pass_context
def gte_command(ctx, *args, **kwargs):
    # Shortcut that runs ``main`` with the GTE embedding model.
    # (Documented with comments on purpose: click would surface a docstring
    # as this command's --help text.)
    # NOTE(review): no click options are declared on this command, so
    # ``args``/``kwargs`` look like they are always empty when it is run from
    # the CLI and the fallback index name is always used -- confirm.
    forwarded = dict(kwargs, embedding='Alibaba-NLP/gte-large-en-v1.5')
    if 'index_name' not in forwarded:
        fallback_index = "index-gte-test_rsids_10k"
        print(f"no index name set, setting up default as {fallback_index}")
        forwarded['index_name'] = fallback_index
    # Hand over to the main command with the adjusted options.
    ctx.invoke(main, *args, **forwarded)
@app.command("bge")
@click.pass_context
def bge_command(ctx, *args, **kwargs):
    # Shortcut that runs ``main`` with the BGE embedding model
    # (alternative model: 'BAAI/bge-base-en-v1.5').
    # (Documented with comments on purpose: click would surface a docstring
    # as this command's --help text.)
    # NOTE(review): no click options are declared on this command, so
    # ``args``/``kwargs`` look like they are always empty when it is run from
    # the CLI and the fallback index name is always used -- confirm.
    forwarded = dict(kwargs, embedding='BAAI/bge-large-en-v1.5')
    if 'index_name' not in forwarded:
        fallback_index = "index-bge-test_rsids_10k"
        print(f"no index name set, setting up default as {fallback_index}")
        forwarded['index_name'] = fallback_index
    # Hand over to the main command with the adjusted options.
    ctx.invoke(main, *args, **forwarded)
@app.command("specter2")
@click.pass_context
def specter_command(ctx, *args, **kwargs):
    # Shortcut that runs ``main`` with the SPECTER2 embedding model
    # (alternative model: 'allenai/specter2_aug2023refresh').
    # (Documented with comments on purpose: click would surface a docstring
    # as this command's --help text.)
    # NOTE(review): no click options are declared on this command, so
    # ``args``/``kwargs`` look like they are always empty when it is run from
    # the CLI and the fallback index name is always used -- confirm.
    forwarded = dict(kwargs, embedding='allenai/specter2_base')
    if 'index_name' not in forwarded:
        fallback_index = "index-specter2-test_rsids_10k"
        print(f"no index name set, setting up default as {fallback_index}")
        forwarded['index_name'] = fallback_index
    # Hand over to the main command with the adjusted options.
    ctx.invoke(main, *args, **forwarded)
# Script entry point: hand control to the click command group.
if __name__ == "__main__":
    app()