Skip to content

Commit

Permalink
Merge pull request #4 from sfc-gh-jcarroll/session-state-index
Browse files Browse the repository at this point in the history
  • Loading branch information
rlancemartin authored Jul 28, 2023
2 parents 1b6f3d5 + 6233afc commit fd09337
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions web_explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.retrievers.web_research import WebResearchRetriever

@st.cache_resource
st.set_page_config(page_title="Interweb Explorer", page_icon="🌐")

def settings():

# Vectorstore
Expand All @@ -26,10 +27,10 @@ def settings():

# Initialize
web_retriever = WebResearchRetriever.from_llm(
vectorstore=vectorstore_public,
llm=llm,
search=search,
num_search_results=3
vectorstore=vectorstore_public,
llm=llm,
search=search,
num_search_results=3
)

return web_retriever, llm
Expand Down Expand Up @@ -62,10 +63,13 @@ def on_retriever_end(self, documents, **kwargs):
st.sidebar.image("img/ai.png")
st.header("`Interweb Explorer`")
st.info("`I am an AI that can answer questions by exploring, reading, and summarizing web pages."
"I can be configured to use different moddes: public API or private (no data sharing).`")
"I can be configured to use different modes: public API or private (no data sharing).`")

# Make retriever and llm
web_retriever, llm = settings()
if 'retriever' not in st.session_state:
st.session_state['retriever'], st.session_state['llm'] = settings()
web_retriever = st.session_state.retriever
llm = st.session_state.llm

# User input
question = st.text_input("`Ask a question:`")
Expand All @@ -76,14 +80,12 @@ def on_retriever_end(self, documents, **kwargs):
import logging
logging.basicConfig()
logging.getLogger("langchain.retrievers.web_research").setLevel(logging.INFO)
qa_chain = RetrievalQAWithSourcesChain.from_chain_type(llm,
retriever=web_retriever)

qa_chain = RetrievalQAWithSourcesChain.from_chain_type(llm, retriever=web_retriever)

# Write answer and sources
retrieval_streamer_cb = PrintRetrievalHandler(st.container())
answer = st.empty()
stream_handler = StreamHandler(answer, initial_text="`Answer:`\n\n")
result = qa_chain({"question": question},callbacks=[retrieval_streamer_cb, stream_handler])
answer.info('`Answer:`\n\n' + result['answer'])
st.info('`Sources:`\n\n' + result['sources'])

0 comments on commit fd09337

Please sign in to comment.