-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
118 lines (85 loc) · 3.47 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import streamlit as st
from io import StringIO
from langchain.chains import ConversationalRetrievalChain
from langchain_openai import ChatOpenAI
from langchain_community.document_loaders import DirectoryLoader
from langchain_openai import OpenAIEmbeddings
from langchain.indexes import VectorstoreIndexCreator
from langchain.memory import ConversationBufferMemory
gpt_model = "gpt-3.5-turbo-1106"
# set local docs for langchain
chat_history = None
memory = None
loader = None
index = None
retriever = None
llm = None
api_key = None
# path to database
infofile = "./database/data.txt"
# assistant prompt
pre_prompt = "You are a friendly and helpful teaching assistant called Cousin. You explain concepts in great depth using simple terms."
# titulo da pagina
st.markdown("<h1 style='text-align: center; color: white;'>Marta-GPT v0.0.1</h1>", unsafe_allow_html=True)
def setup_langchain():
global chat_history, memory, loader, index, llm, retriever, api_key
chat_history = []
memory = ConversationBufferMemory(memory_key="chat_history",return_messages=True)
# set local docs for langchain
embeddings = OpenAIEmbeddings(api_key = api_key)
loader = DirectoryLoader("database/", glob= "**/*.txt")
index = VectorstoreIndexCreator(embedding = embeddings).from_loaders([loader])
#set up chain params:
llm = ChatOpenAI(model = gpt_model, api_key = api_key, temperature = 1, max_tokens = 128)
retriever = index.vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 2, "score_threshold": 1, "fetch_k": 16})
def save_data(data: str) -> None:
"""
Receives User data, to be embedded to the model
"""
try:
with open(infofile, 'a') as file:
file.write(data)
print("Successfully saved data.")
except Exception as e:
print(f"Error {e} has occurred")
def marta(question:str) -> None:
#recieves prompt from user, and returns answer
chain = ConversationalRetrievalChain.from_llm(
llm = llm,
retriever = retriever,
memory = memory,
)
result = chain.invoke({"question": question, "chat_history": chat_history})
chat_history.append((question, result['answer']))
return result['answer'].lower()
# sidebar
with st.sidebar:
st.header("Provide a valid OpenAI API key🗝")
while api_key is None:
api_key = st.sidebar.text_input("your key:", type = "password")
st.header("Provide data files with relevant info📄")
upload = st.file_uploader("Upload a txt file")
if upload is not None:
stringio = StringIO(upload.getvalue().decode("utf-8"))
datafile = stringio.read()
save_data(datafile) # save data from file in path database
setup_langchain()
# Store LLM generated responses
if "messages" not in st.session_state.keys():
st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]
# Display chat messages
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.write(message["content"])
prompt = st.chat_input()
if prompt is not None:
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.write(prompt)
# mandar a questao e receber resposta do langchain
answer = marta(prompt)
#mostrar resposta
with st.chat_message("assistant"):
st.write(answer)
message = {"role": "assistant", "content": answer}
st.session_state.messages.append(message)