-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #840 from Undertone0809/v1.18.0/optimize-streamlit-chatbot
perf: optimize streamlit chatbot
- Loading branch information
Showing 3 changed files with 74 additions and 81 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,72 +1,71 @@ | ||
import pne | ||
import streamlit as st | ||
|
||
# Create a sidebar to place the user parameter configuration | ||
with st.sidebar: | ||
model_options = [ | ||
"openai/gpt-4o", | ||
"openai/gpt-4-turbo", | ||
"deepseek/deepseek-chat", | ||
"zhipu/glm-4", | ||
"ollama/llama2", | ||
"groq/llama-3.1-70b-versatile", | ||
"claude-3-5-sonnet-20240620", | ||
] | ||
|
||
# Add a placeholder for custom model name entry | ||
model_options.insert(0, "Custom Model") | ||
def main(): | ||
with st.sidebar: | ||
model_options = [ | ||
"openai/gpt-4o", | ||
"openai/gpt-4-turbo", | ||
"deepseek/deepseek-chat", | ||
"zhipu/glm-4", | ||
"ollama/llama2", | ||
"groq/llama-3.1-70b-versatile", | ||
"claude-3-5-sonnet-20240620", | ||
] | ||
|
||
selected_option = st.selectbox( | ||
label="Language Model Name", | ||
options=model_options, | ||
) | ||
# Add a placeholder for custom model name entry | ||
model_options.insert(0, "Custom Model") | ||
|
||
model_name = selected_option | ||
if selected_option == "Custom Model": | ||
# Show a text input field for custom model name when "Custom Model" is selected | ||
model_name = st.text_input( | ||
"Enter Custom Model Name", | ||
placeholder="Custom model name, eg: groq/llama3-70b-8192", | ||
help=( | ||
"For more details, please see " | ||
"[how to write model name?](https://www.promptulate.cn/#/other/how_to_write_model_name)" # noqa | ||
), | ||
selected_option = st.selectbox( | ||
label="Language Model Name", | ||
options=model_options, | ||
) | ||
api_key = st.text_input("API Key", key="provider_api_key", type="password") | ||
api_base = st.text_input("OpenAI Proxy URL (Optional)") | ||
|
||
# Set title | ||
st.title("💬 Chat") | ||
st.caption("🚀 Hi there! 👋 I am a simple chatbot by Promptulate to help you.") | ||
model_name = selected_option | ||
if selected_option == "Custom Model": | ||
model_name = st.text_input( | ||
"Enter Custom Model Name", | ||
placeholder="Custom model name, eg: groq/llama3-70b-8192", | ||
help=( | ||
"For more details, please see " | ||
"[how to write model name?](https://www.promptulate.cn/#/other/how_to_write_model_name)" # noqa | ||
), | ||
) | ||
api_key = st.text_input("API Key", key="provider_api_key", type="password") | ||
api_base = st.text_input("OpenAI Proxy URL (Optional)") | ||
|
||
# Determine whether to initialize the message variable | ||
# otherwise initialize a message dictionary | ||
if "messages" not in st.session_state: | ||
st.session_state["messages"] = [ | ||
{"role": "assistant", "content": "How can I help you?"} | ||
] | ||
st.title("💬 Chat") | ||
st.caption("🚀 Hi there! 👋 I am a simple chatbot by Promptulate to help you.") | ||
|
||
# Traverse messages in session state | ||
for msg in st.session_state.messages: | ||
st.chat_message(msg["role"]).write(msg["content"]) | ||
if "messages" not in st.session_state: | ||
st.session_state["messages"] = [ | ||
{"role": "assistant", "content": "How can I help you?"} | ||
] | ||
|
||
# User input | ||
if prompt := st.chat_input(): | ||
if not api_key: | ||
st.info("Please add your API key to continue.") | ||
st.stop() | ||
for msg in st.session_state.messages: | ||
st.chat_message(msg["role"]).write(msg["content"]) | ||
|
||
# Add the message entered by the user to the list of messages in the session state | ||
st.session_state.messages.append({"role": "user", "content": prompt}) | ||
# Display in the chat interface | ||
st.chat_message("user").write(prompt) | ||
if prompt := st.chat_input("How can I help you?"): | ||
if not api_key: | ||
st.info("Please add your API key to continue.") | ||
st.stop() | ||
|
||
response: str = pne.chat( | ||
model=model_name, | ||
stream=True, | ||
messages=prompt, | ||
model_config={"api_base": api_base, "api_key": api_key}, | ||
) | ||
st.session_state.messages.append({"role": "user", "content": prompt}) | ||
|
||
st.session_state.messages.append({"role": "assistant", "content": "start"}) | ||
st.chat_message("assistant").write_stream(response) | ||
with st.chat_message("user"): | ||
st.markdown(prompt) | ||
|
||
with st.chat_message("assistant"): | ||
stream = pne.chat( | ||
model=model_name, | ||
stream=True, | ||
messages=st.session_state.messages, | ||
model_config={"api_base": api_base, "api_key": api_key}, | ||
) | ||
response = st.write_stream(stream) | ||
st.session_state.messages.append({"role": "assistant", "content": response}) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters