Commit 04e9828

Merge pull request #41 from openml-labs/streaming

Streaming, chat improvements, download button

SubhadityaMukherjee authored Aug 26, 2024
2 parents c8cd3ce + 65386a1
Showing 17 changed files with 416 additions and 320 deletions.
2 changes: 1 addition & 1 deletion backend/__init__.py
@@ -1,4 +1,4 @@
 __version__ = "0.1.0"
 from .modules.rag_llm import *
-from .modules.vector_store_utils import *
 from .modules.utils import *
+from .modules.vector_store_utils import *
1 change: 0 additions & 1 deletion backend/modules/metadata_utils.py
@@ -2,7 +2,6 @@
 
 import os
 import pickle
-
 # from pqdm.processes import pqdm
 from typing import Sequence, Tuple, Union
5 changes: 2 additions & 3 deletions backend/modules/results_gen.py
@@ -9,9 +9,8 @@
 import pandas as pd
 from flashrank import Ranker, RerankRequest
 from langchain.chains.retrieval_qa.base import RetrievalQA
-from langchain_community.document_transformers.long_context_reorder import (
-    LongContextReorder,
-)
+from langchain_community.document_transformers.long_context_reorder import \
+    LongContextReorder
 from langchain_core.documents import BaseDocumentTransformer, Document
 from tqdm import tqdm
6 changes: 4 additions & 2 deletions docs/Rag Pipeline/Developer Tutorials/change_model.py
@@ -16,13 +16,15 @@
 # - How would you use a different embedding and llm model?
 
 from __future__ import annotations
-from langchain_community.cache import SQLiteCache
 
 import os
 import sys
 
 import chromadb
+from langchain_community.cache import SQLiteCache
+
-from backend.modules.utils import load_config_and_device
 from backend.modules.rag_llm import QASetup
+from backend.modules.utils import load_config_and_device
 
 # ## Initial config
23 changes: 18 additions & 5 deletions documentation_bot/documentation_query.py
@@ -1,12 +1,13 @@
-from utils import ChromaStore, Crawler
 import os
+import uuid
 
 from fastapi import FastAPI
-from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 from httpx import ConnectTimeout
 from tenacity import retry, retry_if_exception_type, stop_after_attempt
-import uuid
+from utils import ChromaStore, Crawler
 
-#TODO : make this into a separate thing using config
+# TODO : make this into a separate thing using config
 recrawl_websites = False
 
 crawled_files_data_path = "../data/crawler/crawled_data.csv"
@@ -41,6 +42,17 @@
 app = FastAPI()
 session_id = str(uuid.uuid4())
 
+
+def stream_response(response):
+    for line in response:
+        try:
+            yield str(line["answer"])
+        except GeneratorExit:
+            break
+        except:
+            yield ""
+
+
 @app.get("/documentationquery/{query}", response_class=JSONResponse)
 @retry(stop=stop_after_attempt(3), retry=retry_if_exception_type(ConnectTimeout))
 async def get_documentation_query(query: str):
@@ -49,4 +61,5 @@ async def get_documentation_query(query: str):
 
     chroma_store.setup_inference(session_id)
     response = chroma_store.openml_page_search(input=query)
-    return JSONResponse(content=response)
+    # return JSONResponse(content=response)
+    return StreamingResponse(stream_response(response), media_type="text/event-stream")
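
Note: the query endpoint no longer returns a single JSON payload. stream_response re-yields each partial "answer" chunk produced by the chain, and StreamingResponse forwards them to the client as a text/event-stream body. A minimal client sketch for consuming the new endpoint, assuming a local dev server; the host, port, and query text below are illustrative, not taken from this diff:

    import httpx

    query = "How do I create a dataset?"
    # Assumed local address; the deployed URL is not part of this PR.
    url = f"http://localhost:8000/documentationquery/{query}"

    # Read the body incrementally instead of blocking on the full answer.
    with httpx.stream("GET", url, timeout=None) as response:
        for chunk in response.iter_text():
            print(chunk, end="", flush=True)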
41 changes: 26 additions & 15 deletions documentation_bot/utils.py
@@ -1,24 +1,25 @@
-import requests
-from bs4 import BeautifulSoup
 import csv
 import os
+from urllib.parse import urljoin
 
 import pandas as pd
+import requests
 import torch
-from langchain.text_splitter import RecursiveCharacterTextSplitter
+from bs4 import BeautifulSoup
+from langchain.chains import (create_history_aware_retriever,
+                              create_retrieval_chain)
+from langchain.chains.combine_documents import create_stuff_documents_chain
 from langchain.document_loaders import DataFrameLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.chat_message_histories import ChatMessageHistory
 from langchain_community.embeddings import HuggingFaceBgeEmbeddings
 from langchain_community.vectorstores.chroma import Chroma
-from urllib.parse import urljoin
-from tqdm.auto import tqdm
-
-from langchain_community.chat_message_histories import ChatMessageHistory
 from langchain_core.chat_history import BaseChatMessageHistory
-
-from langchain_ollama import ChatOllama
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
 from langchain_core.runnables.history import RunnableWithMessageHistory
+from langchain_ollama import ChatOllama
+from tqdm.auto import tqdm
 
 
 def find_device() -> str:
     """
@@ -30,14 +31,18 @@ def find_device() -> str:
         return "mps"
     return "cpu"
 
-store = {} #TODO : change this to a nicer way of storing the chat history
+
+store = {}  # TODO : change this to a nicer way of storing the chat history
+
 
 def get_session_history(session_id: str) -> BaseChatMessageHistory:
     # print("this is the session id", session_id)
     if session_id not in store:
         store[session_id] = ChatMessageHistory()
     # print("this is the store id", store[session_id])
     return store[session_id]
 
+
 class Crawler:
     """
     Description: This class is used to crawl the OpenML website and gather both code and general information for a bot.
@@ -279,7 +284,7 @@ def setup_inference(self, session_id: str) -> None:
         self.store = {}
         self.session_id = session_id
 
-    def openml_page_search(self, input: str) -> str:
+    def openml_page_search(self, input: str):
 
         vectorstore = Chroma(
             persist_directory=self.chroma_file_path,
@@ -326,10 +331,16 @@ def openml_page_search(self, input: str):
             output_messages_key="answer",
         )
 
-        answer = conversational_rag_chain.invoke(
+        # answer = conversational_rag_chain.invoke(
+        #     {"input": f"{input}"},
+        #     config={
+        #         "configurable": {"session_id": self.session_id}
+        #     },  # constructs a key "abc123" in `store`.
+        # )["answer"]
+        answer = conversational_rag_chain.stream(
             {"input": f"{input}"},
             config={
                 "configurable": {"session_id": self.session_id}
             },  # constructs a key "abc123" in `store`.
-        )["answer"]
+        )
         return answer
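
Note: openml_page_search now returns the generator produced by conversational_rag_chain.stream(...) instead of the final invoke(...)["answer"] string, which is why the -> str annotation was dropped. LangChain retrieval chains stream partial output dicts, and the "answer" key is only present in some chunks; that is what the try/except around line["answer"] in documentation_query.py guards against. A minimal consumption sketch, assuming a chain wrapped in RunnableWithMessageHistory as above (the query text and session id are made up):

    # `chain` stands in for conversational_rag_chain; not a real module-level name.
    for chunk in chain.stream(
        {"input": "What is an OpenML flow?"},
        config={"configurable": {"session_id": "demo-session"}},
    ):
        # Each chunk is a partial dict; pull "answer" only when present.
        if "answer" in chunk:
            print(chunk["answer"], end="", flush=True)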
3 changes: 2 additions & 1 deletion evaluation/experiments.py
@@ -1,6 +1,7 @@
-from training_utils import *
 from concurrent.futures import ThreadPoolExecutor, as_completed
 
+from tqdm.auto import tqdm
+from training_utils import *
 
 
 def exp_0(process_query_elastic_search, eval_path, query_key_dict):
8 changes: 5 additions & 3 deletions evaluation/run_all_training.py
@@ -5,11 +5,13 @@
 from __future__ import annotations
 
 import json
-from pathlib import Path
 import os
-from backend.modules.utils import load_config_and_device
-from training_utils import *
+from pathlib import Path
 
 from experiments import *
+from training_utils import *
+
+from backend.modules.utils import load_config_and_device
 
 if __name__ == "__main__":
     # %%
5 changes: 2 additions & 3 deletions evaluation/training_utils.py
@@ -19,13 +19,12 @@
 # change the path to the backend directory
 sys.path.append(os.path.join(os.path.dirname("."), "../backend/"))
 import requests
+# add modules from ui_utils
+from tqdm.auto import tqdm
 
 # %%
 from backend.modules.rag_llm import *
 from backend.modules.results_gen import *
 
-# add modules from ui_utils
-from tqdm.auto import tqdm
-
 from frontend.ui_utils import *
28 changes: 3 additions & 25 deletions frontend/ui.py
@@ -4,37 +4,15 @@
 from ui_utils import *
 
 # Streamlit Chat Interface
-logo = "images/favicon.ico"
 page_title = "OpenML : A worldwide machine learning lab"
-info = """
-<p style='text-align: center; color: white;'>Machine learning research should be easily accessible and reusable. OpenML is an open platform for sharing datasets, algorithms, and experiments - to learn how to learn better, together. <br>Ask me anything about OpenML or search for a dataset ... </p>
-"""
-chatbot_display = "How do I do X using OpenML? / Find me a dataset about Y"
 
-chatbot_max_chars = 500
 
-st.set_page_config(page_title=page_title, page_icon=logo)
+st.set_page_config(page_title=page_title)
 st.title("OpenML AI Search")
 # message_box = st.container()
 
 with st.spinner("Loading Required Data"):
     config_path = Path("../backend/config.json")
     ui_loader = UILoader(config_path)
 
-# container for company description and logo
-with st.sidebar:
-    query_type = st.radio(
-        "Select Query Type", ["General Query", "Dataset", "Flow"], key="query_type_2"
-    )
-
-user_input = st.chat_input(placeholder=chatbot_display, max_chars=chatbot_max_chars)
-col1, col2 = st.columns([1, 4])
-with col1:
-    st.image(logo, width=100)
-with col2:
-    st.markdown(
-        info,
-        unsafe_allow_html=True,
-    )
-ui_loader.create_chat_interface(user_input=None)
-if user_input:
-    ui_loader.create_chat_interface(user_input, query_type=query_type)
+ui_loader.generate_complete_ui()
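
Note: the inline layout (logo column, info banner, sidebar radio, st.chat_input) moves out of ui.py into UILoader.generate_complete_ui(), whose implementation is not shown in this view. For orientation only, a hypothetical sketch of the pattern such a method typically wraps, built from standard Streamlit chat primitives (every name and string below is invented, not read from ui_utils.py):

    import streamlit as st

    def generate_complete_ui_sketch() -> None:
        # Hypothetical stand-in; the real UILoader method may differ.
        query_type = st.sidebar.radio(
            "Select Query Type", ["General Query", "Dataset", "Flow"]
        )
        user_input = st.chat_input(placeholder="Ask about OpenML...", max_chars=500)
        if user_input:
            with st.chat_message("user"):
                st.write(user_input)
            # Route (user_input, query_type) to the RAG backend and render the reply.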