-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
118 lines (88 loc) · 3.2 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import argparse
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.core import StorageContext, Document
import chromadb
from llama_index.llms.openai import OpenAI
# Embedding model, constructed with the library's default arguments.
embed_model = OpenAIEmbedding()

# Ephemeral Chroma client plus the "quickstart" collection created on it.
chroma_client = chromadb.EphemeralClient()
chroma_collection = chroma_client.create_collection("quickstart")

# Wrap the collection as a LlamaIndex vector store and build the storage
# context that indexes created later in this file are attached to.
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)

# Fallback values for the command line options; the API key is read from
# the OPENAI_API_KEY environment variable once, at import time.
defaults = dict(
    api_key=os.getenv("OPENAI_API_KEY"),
    inputs=".",
)

# LLM client configured with the environment-provided key (may be None).
llm = OpenAI(api_key=defaults["api_key"])
# Function to validate and parse arguments
def validate_and_parse_args(parser):
    """Parse the command line and fall back to parser defaults for blank values.

    Args:
        parser: The ``argparse.ArgumentParser`` describing the CLI. The
            parsed namespace must expose ``api_key`` and ``prompt``.

    Returns:
        The parsed ``argparse.Namespace`` in which every value that was
        ``None`` or a blank string has been replaced by the parser's
        default for that argument.

    Exits:
        Through ``parser.error`` (which terminates the process with exit
        status 2) when no API key or no prompt is available after the
        defaults have been applied.
    """
    args = parser.parse_args()

    # Iterate over a snapshot so the namespace is never mutated while its
    # live dict view is being walked.
    for key, value in list(vars(args).items()):
        # Only None and blank strings mean "not provided". Legitimate falsy
        # values such as 0 or False must not be overwritten by the default.
        is_blank = value is None or (isinstance(value, str) and not value.strip())
        if is_blank:
            setattr(args, key, parser.get_default(key))

    if not args.api_key:
        parser.error(
            "The --api-key argument is required if OPENAI_API_KEY environment variable is not set."
        )
    if not args.prompt:
        parser.error("The --prompt argument is required.")
    return args
def _split_input_paths(raw_inputs):
    """Split a comma separated path list into ``(directories, files)``.

    Whitespace around each entry is stripped and empty entries are ignored,
    so ``"docs, notes.txt,"`` is handled like ``"docs,notes.txt"``.

    Raises:
        ValueError: If no path remains after cleaning, or if any path is
            neither an existing directory nor an existing file.
    """
    paths = [part.strip() for part in raw_inputs.split(",")]
    paths = [path for path in paths if path]
    if not paths:
        raise ValueError("No input files or directories were given.")

    directories, files, invalid = [], [], []
    for path in paths:
        if os.path.isdir(path):
            directories.append(path)
        elif os.path.isfile(path):
            files.append(path)
        else:
            invalid.append(path)

    if invalid:
        raise ValueError(f"Invalid input files or directories: {', '.join(invalid)}")
    return directories, files


def main():
    """Answer a prompt from the command line using retrieval over local files.

    Parses ``--api-key``, ``--prompt`` and ``--inputs``, loads every input
    directory (recursively) and file into documents, indexes them in the
    module-level storage context, retrieves the two most similar chunks and
    prints the query engine's response.

    Raises:
        ValueError: If ``--inputs`` names no usable path or names a path
            that is neither a file nor a directory.
    """
    # Kept as a function-local import, as in the original script.
    from llama_index.core.query_engine import RetrieverQueryEngine

    # Parse the command line arguments
    parser = argparse.ArgumentParser(description="RAG")
    parser.add_argument(
        "-k",
        "--api-key",
        type=str,
        default=defaults["api_key"],
        help="OpenAI API key. Can also be set with OPENAI_API_KEY environment variable.",
    )
    parser.add_argument("-p", "--prompt", type=str, required=True, help="Prompt.")
    parser.add_argument(
        "-i",
        "--inputs",
        type=str,
        default=defaults["inputs"],
        help="Comma separated list of input files or directories.",
    )
    args = validate_and_parse_args(parser)

    directories, files = _split_input_paths(args.inputs)

    # Directories are loaded first, then the individually named files.
    documents: list[Document] = []
    for directory in directories:
        reader = SimpleDirectoryReader(
            input_dir=directory,
            recursive=True,
            exclude_hidden=False,
        )
        documents.extend(reader.load_data(show_progress=True))
    if files:
        reader = SimpleDirectoryReader(
            input_files=files,
            exclude_hidden=False,
        )
        documents.extend(reader.load_data(show_progress=True))

    # Honor a key passed with --api-key. The module-level clients were built
    # from the environment variable only, so they are reused just when the
    # effective key is that same value.
    if args.api_key == defaults["api_key"]:
        answer_llm, embedder = llm, embed_model
    else:
        answer_llm = OpenAI(api_key=args.api_key)
        embedder = OpenAIEmbedding(api_key=args.api_key)

    # `embed_model` is the keyword the index reads; the previous `embedding`
    # keyword was not a recognized parameter, so the model was not applied.
    index = VectorStoreIndex.from_documents(
        documents=documents,
        storage_context=storage_context,
        embed_model=embedder,
    )
    retriever = index.as_retriever(similarity_top_k=2)
    query_engine = RetrieverQueryEngine.from_args(retriever, llm=answer_llm)
    response = query_engine.query(args.prompt)
    print(response)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()