-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp_chat.py
138 lines (128 loc) · 4.37 KB
/
app_chat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import time
import datetime
import os
import joblib
import streamlit as st
from decouple import config
import google.generativeai as genai
# Role name and avatar used for every model-authored chat bubble
MODEL_ROLE = 'ai'
AI_AVATAR_ICON = '🤖'

# Identifier for a chat started in this script run: the current local time
# at minute resolution (day month year hour:minute)
new_chat_id = datetime.datetime.now().strftime('%d %m %Y %H:%M')

# Fetch the API key through decouple's config() and hand it to the Gemini client
GOOGLE_API_KEY = config("GOOGLE_API_KEY")
genai.configure(api_key=GOOGLE_API_KEY)
# Create a data/ folder if it doesn't already exist.
# exist_ok=True makes an existing folder a no-op, while real problems
# (e.g. missing permissions) are raised instead of being silently swallowed.
os.makedirs('data/', exist_ok=True)
# Load the saved index of past chats (chat id -> chat title), if available.
# NOTE: joblib.load unpickles the file, so data/ must only ever contain
# files written by this app.
try:
    past_chats: dict = joblib.load('data/past_chats_list')
except Exception:
    # Best effort: a missing (first run) or unreadable index means we start
    # with no saved chats. Catching Exception rather than using a bare
    # except keeps KeyboardInterrupt/SystemExit propagating.
    past_chats = {}
# Sidebar: lets the user start a new chat or reopen a saved one
with st.sidebar:
    st.write('# Chat History')
    saved_chat_ids = list(past_chats.keys())
    current_chat_id = st.session_state.get('chat_id')

    if current_chat_id is not None:
        # The session already has an active chat: list it right after the
        # "new chat" entry and keep it preselected (index 1).
        def _option_label(option):
            # Unsaved options are shown as 'New Chat', except the active
            # chat, which falls back to its generated title.
            if option != st.session_state.chat_id:
                fallback = 'New Chat'
            else:
                fallback = st.session_state.chat_title
            return past_chats.get(option, fallback)

        st.session_state.chat_id = st.selectbox(
            label='Pick a past chat',
            options=[new_chat_id, st.session_state.chat_id] + saved_chat_ids,
            index=1,
            format_func=_option_label,
            placeholder='_',
        )
    else:
        # No chat chosen yet in this session: offer a new chat first,
        # followed by every saved chat.
        st.session_state.chat_id = st.selectbox(
            label='Check past chat history',
            options=[new_chat_id] + saved_chat_ids,
            format_func=lambda option: past_chats.get(option, 'New Chat'),
            placeholder='_',
        )

    # Title under which the chat is stored once a message has been sent
    # TODO: Give user a chance to name chat
    st.session_state.chat_title = f'ChatSession-{st.session_state.chat_id}'
st.write('# can we chat?')

# Restore both histories of the selected chat: the messages rendered in the
# UI and the history handed to the Gemini chat session. If either file cannot
# be loaded, both are reset so they never get out of step.
try:
    st.session_state.messages = joblib.load(
        f'data/{st.session_state.chat_id}-st_messages'
    )
    st.session_state.gemini_history = joblib.load(
        f'data/{st.session_state.chat_id}-gemini_messages'
    )
    print('old cache')
except Exception:
    # Best effort: nothing saved yet for this chat id (or the files are
    # unreadable), so start from an empty conversation. Catching Exception
    # rather than using a bare except lets KeyboardInterrupt/SystemExit
    # propagate.
    st.session_state.messages = []
    st.session_state.gemini_history = []
    print('new_cache made')

# (Re)create the model and a chat session seeded with the restored history
# NOTE(review): confirm 'gemini-pro' is still an available model name
st.session_state.model = genai.GenerativeModel('gemini-pro')
st.session_state.chat = st.session_state.model.start_chat(
    history=st.session_state.gemini_history,
)
# Replay every stored message so the conversation is visible after a rerun
for past_message in st.session_state.messages:
    speaker = past_message['role']
    icon = past_message.get('avatar')  # user messages are stored without one
    with st.chat_message(name=speaker, avatar=icon):
        st.markdown(past_message['content'])
# React to user input: runs only when st.chat_input yields a truthy prompt
if prompt := st.chat_input('Your message here...'):
    # Register this chat in the saved-chats index the first time it is used
    if st.session_state.chat_id not in past_chats:
        past_chats[st.session_state.chat_id] = st.session_state.chat_title
        joblib.dump(past_chats, 'data/past_chats_list')

    # Display user message in chat message container
    with st.chat_message('user'):
        st.markdown(prompt)

    # Add user message to chat history
    st.session_state.messages.append(
        dict(
            role='user',
            content=prompt,
        )
    )

    # Send message to AI and receive the reply as a stream of chunks
    response = st.session_state.chat.send_message(
        prompt,
        stream=True,
    )

    # Display assistant response in chat message container
    with st.chat_message(
        name=MODEL_ROLE,
        avatar=AI_AVATAR_ICON,
    ):
        message_placeholder = st.empty()
        full_response = ''
        # Streams in a chunk at a time
        for chunk in response:
            # Simulate a word-by-word stream within each chunk
            # TODO: Chunk missing `text` if API stops mid-stream ("safety"?)
            for position, word in enumerate(chunk.text.split(' ')):
                # Put back only the spaces that split(' ') removed, so the
                # displayed text equals the concatenated chunk text: no
                # extra space at chunk boundaries and none at the end.
                if position:
                    full_response += ' '
                full_response += word
                time.sleep(0.05)
                # Rewrites with a cursor at end
                message_placeholder.write(full_response + '▌')
        # Final render without the cursor
        message_placeholder.write(full_response)

    # Add assistant response to chat history; the stored text is taken from
    # the chat session's own history rather than from the streamed display
    st.session_state.messages.append(
        dict(
            role=MODEL_ROLE,
            content=st.session_state.chat.history[-1].parts[0].text,
            avatar=AI_AVATAR_ICON,
        )
    )
    st.session_state.gemini_history = st.session_state.chat.history

    # Save both histories to file under the current chat id
    joblib.dump(
        st.session_state.messages,
        f'data/{st.session_state.chat_id}-st_messages',
    )
    joblib.dump(
        st.session_state.gemini_history,
        f'data/{st.session_state.chat_id}-gemini_messages',
    )