From 979d246be23bf7ecb341f80cfa894bbb99802de9 Mon Sep 17 00:00:00 2001 From: SubhadityaMukherjee Date: Mon, 26 Aug 2024 12:22:37 +0200 Subject: [PATCH 1/4] fixed streaming and better chat interface --- documentation_bot/documentation_query.py | 14 +++- documentation_bot/utils.py | 12 +++- frontend/ui.py | 24 ++++--- frontend/ui_utils.py | 92 ++++++++++++++++-------- 4 files changed, 100 insertions(+), 42 deletions(-) diff --git a/documentation_bot/documentation_query.py b/documentation_bot/documentation_query.py index c45aae2..9ce658f 100644 --- a/documentation_bot/documentation_query.py +++ b/documentation_bot/documentation_query.py @@ -5,6 +5,7 @@ from httpx import ConnectTimeout from tenacity import retry, retry_if_exception_type, stop_after_attempt import uuid +from fastapi.responses import StreamingResponse #TODO : make this into a separate thing using config recrawl_websites = False @@ -41,6 +42,15 @@ app = FastAPI() session_id = str(uuid.uuid4()) +def stream_response(response): + for line in response: + try: + yield str(line["answer"]) + except GeneratorExit: + break + except: + yield "" + @app.get("/documentationquery/{query}", response_class=JSONResponse) @retry(stop=stop_after_attempt(3), retry=retry_if_exception_type(ConnectTimeout)) async def get_documentation_query(query: str): @@ -49,4 +59,6 @@ async def get_documentation_query(query: str): chroma_store.setup_inference(session_id) response = chroma_store.openml_page_search(input=query) - return JSONResponse(content=response) + # return JSONResponse(content=response) + return StreamingResponse(stream_response(response), media_type="text/event-stream") + diff --git a/documentation_bot/utils.py b/documentation_bot/utils.py index 6a1459b..a430d17 100644 --- a/documentation_bot/utils.py +++ b/documentation_bot/utils.py @@ -279,7 +279,7 @@ def setup_inference(self, session_id: str) -> None: self.store = {} self.session_id = session_id - def openml_page_search(self, input: str) -> str: + def openml_page_search(self, input: str): vectorstore = Chroma( persist_directory=self.chroma_file_path, @@ -326,10 +326,16 @@ def openml_page_search(self, input: str) -> str: output_messages_key="answer", ) - answer = conversational_rag_chain.invoke( + # answer = conversational_rag_chain.invoke( + # {"input": f"{input}"}, + # config={ + # "configurable": {"session_id": self.session_id} + # }, # constructs a key "abc123" in `store`. + # )["answer"] + answer = conversational_rag_chain.stream( {"input": f"{input}"}, config={ "configurable": {"session_id": self.session_id} }, # constructs a key "abc123" in `store`. 
- )["answer"] + ) return answer diff --git a/frontend/ui.py b/frontend/ui.py index afa08ed..ed6ffa8 100644 --- a/frontend/ui.py +++ b/frontend/ui.py @@ -21,12 +21,10 @@ ui_loader = UILoader(config_path) # container for company description and logo -with st.sidebar: - query_type = st.radio( - "Select Query Type", ["General Query", "Dataset", "Flow"], key="query_type_2" - ) - -user_input = st.chat_input(placeholder=chatbot_display, max_chars=chatbot_max_chars) +# with st.sidebar: +# query_type = st.radio( +# "Select Query Type", ["General Query", "Dataset", "Flow"], key="query_type_2" +# ) col1, col2 = st.columns([1, 4]) with col1: st.image(logo, width=100) @@ -35,6 +33,14 @@ info, unsafe_allow_html=True, ) -ui_loader.create_chat_interface(user_input=None) -if user_input: - ui_loader.create_chat_interface(user_input, query_type=query_type) +chat_container = st.container() +with chat_container: + query_type = st.radio( + "Select Query Type", ["General Query", "Dataset", "Flow"], key="query_type_2" + ) + user_input = st.chat_input(placeholder=chatbot_display, max_chars=chatbot_max_chars) + + + ui_loader.create_chat_interface(user_input=None) + if user_input: + ui_loader.create_chat_interface(user_input, query_type=query_type) diff --git a/frontend/ui_utils.py b/frontend/ui_utils.py index 7826514..d8139d8 100644 --- a/frontend/ui_utils.py +++ b/frontend/ui_utils.py @@ -139,21 +139,21 @@ def fetch_llm_response(self, query: str): ).json() return self.llm_response - def fetch_documentation_query(self, query: str): - """ - Description: Fetch the response for a general or documentation or code query from the LLM service as a JSON - """ - documentation_response_path = self.paths["documentation_query"] - try: - self.documentation_response = requests.get( - f"{documentation_response_path['docker']}{query}", - json={"query": query}, - ).json() - except: - self.documentation_response = requests.get( - f"{documentation_response_path['local']}{query}", - json={"query": query}, - ).json() + # def fetch_documentation_query(self, query: str): + # """ + # Description: Fetch the response for a general or documentation or code query from the LLM service as a JSON + # """ + # documentation_response_path = self.paths["documentation_query"] + # try: + # self.documentation_response = requests.get( + # f"{documentation_response_path['docker']}{query}", + # json={"query": query}, + # ).json() + # except: + # self.documentation_response = requests.get( + # f"{documentation_response_path['local']}{query}", + # json={"query": query}, + # ).json() def fetch_structured_query(self, query_type: str, query: str): """ @@ -365,6 +365,7 @@ def __init__(self, config_path): # defaults self.query_type = "Dataset" self.llm_filter = False + self.paths = self.load_paths() if "messages" not in st.session_state: st.session_state.messages = [] @@ -394,19 +395,35 @@ def create_chat_interface(self, user_input, query_type=None): st.session_state.messages.append({"role": "user", "content": user_input}) with st.spinner("Waiting for results..."): results = self.process_query_chat(user_input) - - st.session_state.messages.append( - {"role": "OpenML Agent", "content": results} - ) - - # Display chat history - for message in st.session_state.messages: - if message["role"] == "user": - with st.chat_message(name = "user"): - self.display_results(message["content"], "user") + + if not self.query_type == "General Query": + st.session_state.messages.append( + {"role": "OpenML Agent", "content": results} + ) else: - with st.chat_message(name = "ai"): - 
self.display_results(message["content"], "ai") + with st.spinner("Fetching results..."): + with requests.get(results, stream=True) as r: + resp_contain = st.empty() + streamed_response = "" + for chunk in r.iter_content(chunk_size=1024): + if chunk: + streamed_response += chunk.decode("utf-8") + resp_contain.markdown(streamed_response) + resp_contain.empty() + st.session_state.messages.append( + {"role": "OpenML Agent", "content": streamed_response} + ) + + # Display chat history + for message in st.session_state.messages: + if query_type == "General Query": + pass + if message["role"] == "user": + with st.chat_message(name = "user"): + self.display_results(message["content"], "user") + else: + with st.chat_message(name = "ai"): + self.display_results(message["content"], "ai") def display_results(self,initial_response, role): """ @@ -468,5 +485,22 @@ def process_query_chat(self, query): results = response_parser.parse_and_update_response(self.data_metadata) return results elif self.query_type == "General Query": - response_parser.fetch_documentation_query(query) - return response_parser.documentation_response + # response_parser.fetch_documentation_query(query) + # return response_parser.documentation_response + documentation_response_path = self.paths["documentation_query"]["local"] + query + # with requests.get(documentation_response_path, stream=True) as r: + # resp_contain = st.empty() + # streamed_response = "" + # for chunk in r.iter_content(chunk_size=1024): + # if chunk: + # streamed_response += chunk.decode("utf-8") + # resp_contain.markdown(streamed_response) + # return requests.get(documentation_response_path, stream=True) + return documentation_response_path + + def load_paths(self): + """ + Description: Load paths from paths.json + """ + with open("paths.json", "r") as file: + return json.load(file) From 35037c4437425b11c18ac3b85e1f8767f4e411b7 Mon Sep 17 00:00:00 2001 From: SubhadityaMukherjee Date: Mon, 26 Aug 2024 13:43:14 +0200 Subject: [PATCH 2/4] better chat interface --- frontend/ui.py | 9 +++++---- frontend/ui_utils.py | 1 - 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/frontend/ui.py b/frontend/ui.py index ed6ffa8..9907aaf 100644 --- a/frontend/ui.py +++ b/frontend/ui.py @@ -33,13 +33,14 @@ info, unsafe_allow_html=True, ) + chat_container = st.container() with chat_container: - query_type = st.radio( - "Select Query Type", ["General Query", "Dataset", "Flow"], key="query_type_2" - ) - user_input = st.chat_input(placeholder=chatbot_display, max_chars=chatbot_max_chars) + with st.form(key="chat_form"): + user_input = st.text_input(label=chatbot_display) + query_type = st.selectbox("Select Query Type", ["General Query", "Dataset", "Flow"]) + submit_button = st.form_submit_button(label="Submit") ui_loader.create_chat_interface(user_input=None) if user_input: diff --git a/frontend/ui_utils.py b/frontend/ui_utils.py index d8139d8..76004ba 100644 --- a/frontend/ui_utils.py +++ b/frontend/ui_utils.py @@ -378,7 +378,6 @@ def __init__(self, config_path): # return st.chat_input( # self.chatbot_display, max_chars=self.chatbot_input_max_chars # ) - def create_chat_interface(self, user_input, query_type=None): """ Description: Create the chat interface and display the chat history and results. Show the user input and the response from the OpenML Agent. 
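A note on the streaming path that patch 1 introduces before moving on to patch 3: the LangChain chain's .stream() generator is wrapped in FastAPI's StreamingResponse on the server, and frontend/ui_utils.py consumes it with requests' iter_content, redrawing an st.empty() placeholder as chunks arrive. Below is a minimal, self-contained sketch of that pattern; the /demoquery route and the fake_llm_stream generator are illustrative stand-ins for the real endpoint and for conversational_rag_chain.stream(), not code from this series.

import time

import requests
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


def fake_llm_stream(query: str):
    # Stand-in for conversational_rag_chain.stream(...): the real chain yields
    # dicts whose partial text sits under the "answer" key.
    for word in f"echoing back: {query}".split():
        time.sleep(0.1)  # simulate per-token generation latency
        yield {"answer": word + " "}


@app.get("/demoquery/{query}")  # illustrative route, not the real endpoint
async def demo_query(query: str):
    def gen():
        # Forward each partial answer as a plain-text chunk.
        for piece in fake_llm_stream(query):
            yield str(piece["answer"])

    return StreamingResponse(gen(), media_type="text/event-stream")


def consume(url: str) -> str:
    # Client side: accumulate chunks and redraw on every update, the same way
    # the UI rewrites an st.empty() placeholder via resp_contain.markdown().
    streamed = ""
    with requests.get(url, stream=True) as r:
        for chunk in r.iter_content(chunk_size=1024):
            if chunk:
                streamed += chunk.decode("utf-8")
    return streamed

Two caveats worth noting: iter_content(chunk_size=1024) splits on byte boundaries, so a multi-byte UTF-8 character can straddle two chunks and chunk.decode("utf-8") would raise; iter_content(decode_unicode=True) sidesteps that if non-ASCII answers are expected. Also, the payload here is plain text rather than SSE-framed "data:" lines, so media_type="text/plain" would describe it just as well as "text/event-stream".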
From 0d36a3ba50edab7f814d3a69a06a1597c41197cd Mon Sep 17 00:00:00 2001 From: SubhadityaMukherjee Date: Mon, 26 Aug 2024 14:17:10 +0200 Subject: [PATCH 3/4] added download chat button --- frontend/ui.py | 11 +++-- frontend/ui_utils.py | 110 +++++++++++++++++++++++++------------------ 2 files changed, 70 insertions(+), 51 deletions(-) diff --git a/frontend/ui.py b/frontend/ui.py index 9907aaf..a072b02 100644 --- a/frontend/ui.py +++ b/frontend/ui.py @@ -21,10 +21,6 @@ ui_loader = UILoader(config_path) # container for company description and logo -# with st.sidebar: -# query_type = st.radio( -# "Select Query Type", ["General Query", "Dataset", "Flow"], key="query_type_2" -# ) col1, col2 = st.columns([1, 4]) with col1: st.image(logo, width=100) @@ -35,11 +31,16 @@ ) chat_container = st.container() + + + with chat_container: with st.form(key="chat_form"): user_input = st.text_input(label=chatbot_display) - query_type = st.selectbox("Select Query Type", ["General Query", "Dataset", "Flow"]) + query_type = st.selectbox( + "Select Query Type", ["General Query", "Dataset", "Flow"] + ) submit_button = st.form_submit_button(label="Submit") ui_loader.create_chat_interface(user_input=None) diff --git a/frontend/ui_utils.py b/frontend/ui_utils.py index 76004ba..1e46d54 100644 --- a/frontend/ui_utils.py +++ b/frontend/ui_utils.py @@ -13,7 +13,6 @@ from structured_query.chroma_store_utilis import * - def feedback_cb(): """ Description: Callback function to save feedback to a file @@ -37,7 +36,6 @@ def feedback_cb(): json.dump(data, file, indent=4) - class LLMResponseParser: """ Description: Parse the response from the LLM service and update the columns based on the response. @@ -138,22 +136,6 @@ def fetch_llm_response(self, query: str): f"{llm_response_path['local']}{query}" ).json() return self.llm_response - - # def fetch_documentation_query(self, query: str): - # """ - # Description: Fetch the response for a general or documentation or code query from the LLM service as a JSON - # """ - # documentation_response_path = self.paths["documentation_query"] - # try: - # self.documentation_response = requests.get( - # f"{documentation_response_path['docker']}{query}", - # json={"query": query}, - # ).json() - # except: - # self.documentation_response = requests.get( - # f"{documentation_response_path['local']}{query}", - # json={"query": query}, - # ).json() def fetch_structured_query(self, query_type: str, query: str): """ @@ -175,7 +157,6 @@ def fetch_structured_query(self, query_type: str, query: str): f"{structured_response_path['local']}{query}", json={"query": query}, ).json() - # except (requests.exceptions.RequestException, json.JSONDecodeError) as e: except Exception as e: # Print the error for debugging purposes print(f"Error occurred while fetching from local endpoint: {e}") @@ -228,9 +209,11 @@ def parse_and_update_response(self, metadata: pd.DataFrame): - Metadata is filtered based on the rag response first and then by the Query parsing LLM - self.apply_llm_before_rag == False - Metadata is filtered based by the Query parsing LLM first and the rag response second - - in case structured_query == true, take results are applying data filters. + - in case structured_query == true, take results are applying data filters. 
""" - if (self.apply_llm_before_rag is None or self.llm_response is None) and not config["structured_query"]: + if ( + self.apply_llm_before_rag is None or self.llm_response is None + ) and not config["structured_query"]: print("No LLM filter.") # print(self.rag_response, flush=True) filtered_metadata = metadata[ @@ -249,7 +232,9 @@ def parse_and_update_response(self, metadata: pd.DataFrame): # if no llm response is required, return the initial response return filtered_metadata - elif (self.rag_response is not None and self.llm_response is not None) and not config["structured_query"]: + elif ( + self.rag_response is not None and self.llm_response is not None + ) and not config["structured_query"]: if not self.apply_llm_before_rag: print("RAG before LLM filter.") filtered_metadata = metadata[ @@ -288,27 +273,34 @@ def parse_and_update_response(self, metadata: pd.DataFrame): elif ( self.rag_response is not None and self.structured_query_response is not None - ): + ): col_name = [ "status", "NumberOfClasses", "NumberOfFeatures", "NumberOfInstances", ] - print(self.structured_query_response) # Only for debugging. Comment later. - if self.structured_query_response[0] is not None and isinstance(self.structured_query_response[1], dict): + print(self.structured_query_response) # Only for debugging. Comment later. + if self.structured_query_response[0] is not None and isinstance( + self.structured_query_response[1], dict + ): # Safely attempt to access the "filter" key in the first element - - if self.structured_query_response[0].get("filter", None) and self.database_filtered: + + if ( + self.structured_query_response[0].get("filter", None) + and self.database_filtered + ): filtered_metadata = metadata[ metadata["did"].isin(self.database_filtered) ] print("Showing database filtered data") else: filtered_metadata = metadata[ - metadata["did"].isin(self.rag_response["initial_response"]) + metadata["did"].isin(self.rag_response["initial_response"]) ] - print("Showing only rag response as filter is empty or none of the rag data satisfies filter conditions.") + print( + "Showing only rag response as filter is empty or none of the rag data satisfies filter conditions." + ) filtered_metadata["did"] = pd.Categorical( filtered_metadata["did"], categories=self.rag_response["initial_response"], @@ -317,7 +309,7 @@ def parse_and_update_response(self, metadata: pd.DataFrame): filtered_metadata = filtered_metadata.sort_values("did").reset_index( drop=True ) - + else: filtered_metadata = metadata[ metadata["did"].isin(self.rag_response["initial_response"]) @@ -386,16 +378,16 @@ def create_chat_interface(self, user_input, query_type=None): self.query_type = query_type # self.llm_filter = llm_filter if user_input is None: - with st.chat_message(name = "ai"): + with st.chat_message(name="ai"): st.write("OpenML Agent: ", "Hello! 
How can I help you today?") - + # Handle user input if user_input: st.session_state.messages.append({"role": "user", "content": user_input}) with st.spinner("Waiting for results..."): results = self.process_query_chat(user_input) - - if not self.query_type == "General Query": + + if not self.query_type == "General Query": st.session_state.messages.append( {"role": "OpenML Agent", "content": results} ) @@ -412,24 +404,40 @@ def create_chat_interface(self, user_input, query_type=None): st.session_state.messages.append( {"role": "OpenML Agent", "content": streamed_response} ) + + # reverse messages to show the latest message at the top + reversed_messages = [] + for index in range(0, len(st.session_state.messages),2): + reversed_messages.insert(0, st.session_state.messages[index]) + reversed_messages.insert(1, st.session_state.messages[index + 1]) # Display chat history - for message in st.session_state.messages: + for message in reversed_messages: if query_type == "General Query": pass if message["role"] == "user": - with st.chat_message(name = "user"): + with st.chat_message(name="user"): self.display_results(message["content"], "user") else: - with st.chat_message(name = "ai"): + with st.chat_message(name="ai"): self.display_results(message["content"], "ai") + self.create_download_button() + + @st.experimental_fragment() + def create_download_button(self): + data = "\n".join([str(message["content"]) for message in st.session_state.messages]) + st.download_button( + label="Download chat history", + data=data, + file_name="chat_history.txt", + ) - def display_results(self,initial_response, role): + def display_results(self, initial_response, role): """ Description: Display the results in a DataFrame """ # st.write("OpenML Agent: ") - + try: st.dataframe(initial_response) # self.message_box.chat_message(role).write(st.dataframe(initial_response)) @@ -454,22 +462,30 @@ def process_query_chat(self, query): response_parser.fetch_structured_query(self.query_type, query) try: # get rag response - # using original query instead of extracted topics. + # using original query instead of extracted topics. 
response_parser.fetch_rag_response( self.query_type, response_parser.structured_query_response[0]["query"], ) - + if response_parser.structured_query_response: - st.write("Detected Filter(s): ", json.dumps(response_parser.structured_query_response[0].get("filter", None))) + st.write( + "Detected Filter(s): ", + json.dumps( + response_parser.structured_query_response[0].get( + "filter", None + ) + ), + ) else: st.write("Detected Filter(s): ", None) # st.write("Detected Topics: ", response_parser.structured_query_response[0].get("query", None)) if response_parser.structured_query_response[1].get("filter"): - + with st.spinner("Applying LLM Detected Filter(s)..."): response_parser.database_filter( - response_parser.structured_query_response[1]["filter"], collec + response_parser.structured_query_response[1]["filter"], + collec, ) except: # fallback to RAG response @@ -486,7 +502,9 @@ def process_query_chat(self, query): elif self.query_type == "General Query": # response_parser.fetch_documentation_query(query) # return response_parser.documentation_response - documentation_response_path = self.paths["documentation_query"]["local"] + query + documentation_response_path = ( + self.paths["documentation_query"]["local"] + query + ) # with requests.get(documentation_response_path, stream=True) as r: # resp_contain = st.empty() # streamed_response = "" @@ -496,7 +514,7 @@ def process_query_chat(self, query): # resp_contain.markdown(streamed_response) # return requests.get(documentation_response_path, stream=True) return documentation_response_path - + def load_paths(self): """ Description: Load paths from paths.json From 65386a11f361f283b24ce1b4bf9ca674c056318b Mon Sep 17 00:00:00 2001 From: SubhadityaMukherjee Date: Mon, 26 Aug 2024 15:23:23 +0200 Subject: [PATCH 4/4] better UI, reformatted UI scripts --- backend/__init__.py | 2 +- backend/modules/metadata_utils.py | 1 - backend/modules/results_gen.py | 5 +- .../Developer Tutorials/change_model.py | 6 +- documentation_bot/documentation_query.py | 13 +- documentation_bot/utils.py | 29 +- evaluation/experiments.py | 3 +- evaluation/run_all_training.py | 8 +- evaluation/training_utils.py | 5 +- frontend/ui.py | 36 +- frontend/ui_utils.py | 444 +++++++++--------- structured_query/chroma_store_utilis.py | 10 +- .../llm_service_structured_query.py | 22 +- .../llm_service_structured_query_utils.py | 11 +- structured_query/structuring_query.py | 6 +- useful_scripts/test_selfquery.ipynb | 113 +++-- .../test_strutured_query_pipeline.ipynb | 3 - 17 files changed, 368 insertions(+), 349 deletions(-) diff --git a/backend/__init__.py b/backend/__init__.py index 4a02c7d..5c665c9 100644 --- a/backend/__init__.py +++ b/backend/__init__.py @@ -1,4 +1,4 @@ __version__ = "0.1.0" from .modules.rag_llm import * -from .modules.vector_store_utils import * from .modules.utils import * +from .modules.vector_store_utils import * diff --git a/backend/modules/metadata_utils.py b/backend/modules/metadata_utils.py index 00860d8..e9a9ff0 100644 --- a/backend/modules/metadata_utils.py +++ b/backend/modules/metadata_utils.py @@ -2,7 +2,6 @@ import os import pickle - # from pqdm.processes import pqdm from typing import Sequence, Tuple, Union diff --git a/backend/modules/results_gen.py b/backend/modules/results_gen.py index 9618be8..52ed937 100644 --- a/backend/modules/results_gen.py +++ b/backend/modules/results_gen.py @@ -9,9 +9,8 @@ import pandas as pd from flashrank import Ranker, RerankRequest from langchain.chains.retrieval_qa.base import RetrievalQA -from 
langchain_community.document_transformers.long_context_reorder import ( - LongContextReorder, -) +from langchain_community.document_transformers.long_context_reorder import \ + LongContextReorder from langchain_core.documents import BaseDocumentTransformer, Document from tqdm import tqdm diff --git a/docs/Rag Pipeline/Developer Tutorials/change_model.py b/docs/Rag Pipeline/Developer Tutorials/change_model.py index 01ba86f..fd8cc45 100644 --- a/docs/Rag Pipeline/Developer Tutorials/change_model.py +++ b/docs/Rag Pipeline/Developer Tutorials/change_model.py @@ -16,13 +16,15 @@ # - How would you use a different embedding and llm model? from __future__ import annotations -from langchain_community.cache import SQLiteCache + import os import sys + import chromadb +from langchain_community.cache import SQLiteCache -from backend.modules.utils import load_config_and_device from backend.modules.rag_llm import QASetup +from backend.modules.utils import load_config_and_device # ## Initial config diff --git a/documentation_bot/documentation_query.py b/documentation_bot/documentation_query.py index 9ce658f..b3ede0e 100644 --- a/documentation_bot/documentation_query.py +++ b/documentation_bot/documentation_query.py @@ -1,13 +1,13 @@ -from utils import ChromaStore, Crawler import os +import uuid + from fastapi import FastAPI -from fastapi.responses import JSONResponse +from fastapi.responses import JSONResponse, StreamingResponse from httpx import ConnectTimeout from tenacity import retry, retry_if_exception_type, stop_after_attempt -import uuid -from fastapi.responses import StreamingResponse +from utils import ChromaStore, Crawler -#TODO : make this into a separate thing using config +# TODO : make this into a separate thing using config recrawl_websites = False crawled_files_data_path = "../data/crawler/crawled_data.csv" @@ -42,6 +42,7 @@ app = FastAPI() session_id = str(uuid.uuid4()) + def stream_response(response): for line in response: try: @@ -51,6 +52,7 @@ def stream_response(response): except: yield "" + @app.get("/documentationquery/{query}", response_class=JSONResponse) @retry(stop=stop_after_attempt(3), retry=retry_if_exception_type(ConnectTimeout)) async def get_documentation_query(query: str): @@ -61,4 +63,3 @@ async def get_documentation_query(query: str): response = chroma_store.openml_page_search(input=query) # return JSONResponse(content=response) return StreamingResponse(stream_response(response), media_type="text/event-stream") - diff --git a/documentation_bot/utils.py b/documentation_bot/utils.py index a430d17..9931649 100644 --- a/documentation_bot/utils.py +++ b/documentation_bot/utils.py @@ -1,24 +1,25 @@ -import requests -from bs4 import BeautifulSoup import csv import os +from urllib.parse import urljoin + import pandas as pd +import requests import torch -from langchain.text_splitter import RecursiveCharacterTextSplitter +from bs4 import BeautifulSoup +from langchain.chains import (create_history_aware_retriever, + create_retrieval_chain) +from langchain.chains.combine_documents import create_stuff_documents_chain from langchain.document_loaders import DataFrameLoader +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain_community.chat_message_histories import ChatMessageHistory from langchain_community.embeddings import HuggingFaceBgeEmbeddings from langchain_community.vectorstores.chroma import Chroma -from urllib.parse import urljoin -from tqdm.auto import tqdm - -from langchain_community.chat_message_histories import ChatMessageHistory from 
langchain_core.chat_history import BaseChatMessageHistory - -from langchain_ollama import ChatOllama from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.runnables.history import RunnableWithMessageHistory -from langchain.chains import create_history_aware_retriever, create_retrieval_chain -from langchain.chains.combine_documents import create_stuff_documents_chain +from langchain_ollama import ChatOllama +from tqdm.auto import tqdm + def find_device() -> str: """ @@ -30,7 +31,10 @@ def find_device() -> str: return "mps" return "cpu" -store = {} #TODO : change this to a nicer way of storing the chat history + +store = {} # TODO : change this to a nicer way of storing the chat history + + def get_session_history(session_id: str) -> BaseChatMessageHistory: # print("this is the session id", session_id) if session_id not in store: @@ -38,6 +42,7 @@ def get_session_history(session_id: str) -> BaseChatMessageHistory: # print("this is the store id", store[session_id]) return store[session_id] + class Crawler: """ Description: This class is used to crawl the OpenML website and gather both code and general information for a bot. diff --git a/evaluation/experiments.py b/evaluation/experiments.py index cdbe1ac..247e34f 100644 --- a/evaluation/experiments.py +++ b/evaluation/experiments.py @@ -1,6 +1,7 @@ -from training_utils import * from concurrent.futures import ThreadPoolExecutor, as_completed + from tqdm.auto import tqdm +from training_utils import * def exp_0(process_query_elastic_search, eval_path, query_key_dict): diff --git a/evaluation/run_all_training.py b/evaluation/run_all_training.py index 772fbd1..eb5ce94 100644 --- a/evaluation/run_all_training.py +++ b/evaluation/run_all_training.py @@ -5,11 +5,13 @@ from __future__ import annotations import json -from pathlib import Path import os -from backend.modules.utils import load_config_and_device -from training_utils import * +from pathlib import Path + from experiments import * +from training_utils import * + +from backend.modules.utils import load_config_and_device if __name__ == "__main__": # %% diff --git a/evaluation/training_utils.py b/evaluation/training_utils.py index 90c47fe..28aa5f3 100644 --- a/evaluation/training_utils.py +++ b/evaluation/training_utils.py @@ -19,13 +19,12 @@ # change the path to the backend directory sys.path.append(os.path.join(os.path.dirname("."), "../backend/")) import requests +# add modules from ui_utils +from tqdm.auto import tqdm # %% from backend.modules.rag_llm import * from backend.modules.results_gen import * - -# add modules from ui_utils -from tqdm.auto import tqdm from frontend.ui_utils import * diff --git a/frontend/ui.py b/frontend/ui.py index a072b02..14dec5f 100644 --- a/frontend/ui.py +++ b/frontend/ui.py @@ -4,45 +4,15 @@ from ui_utils import * # Streamlit Chat Interface -logo = "images/favicon.ico" page_title = "OpenML : A worldwide machine learning lab" -info = """ -
-    Machine learning research should be easily accessible and reusable. OpenML is an open platform for sharing datasets, algorithms, and experiments - to learn how to learn better, together.
-    Ask me anything about OpenML or search for a dataset ...
- """ -chatbot_display = "How do I do X using OpenML? / Find me a dataset about Y" + chatbot_max_chars = 500 -st.set_page_config(page_title=page_title, page_icon=logo) +st.set_page_config(page_title=page_title) st.title("OpenML AI Search") # message_box = st.container() with st.spinner("Loading Required Data"): config_path = Path("../backend/config.json") ui_loader = UILoader(config_path) - -# container for company description and logo -col1, col2 = st.columns([1, 4]) -with col1: - st.image(logo, width=100) -with col2: - st.markdown( - info, - unsafe_allow_html=True, - ) - -chat_container = st.container() - - - -with chat_container: - - with st.form(key="chat_form"): - user_input = st.text_input(label=chatbot_display) - query_type = st.selectbox( - "Select Query Type", ["General Query", "Dataset", "Flow"] - ) - submit_button = st.form_submit_button(label="Submit") - - ui_loader.create_chat_interface(user_input=None) - if user_input: - ui_loader.create_chat_interface(user_input, query_type=query_type) + ui_loader.generate_complete_ui() diff --git a/frontend/ui_utils.py b/frontend/ui_utils.py index 1e46d54..fb543d7 100644 --- a/frontend/ui_utils.py +++ b/frontend/ui_utils.py @@ -1,41 +1,16 @@ import json -import os +import sys +from pathlib import Path +import pandas as pd import requests import streamlit as st from streamlit import session_state as ss -from langchain_community.query_constructors.chroma import ChromaTranslator -import pandas as pd -import sys -from pathlib import Path sys.path.append("../") from structured_query.chroma_store_utilis import * -def feedback_cb(): - """ - Description: Callback function to save feedback to a file - """ - file_path = "../data/feedback.json" - - if os.path.exists(file_path): - with open(file_path, "r") as file: - try: - data = json.load(file) - except json.JSONDecodeError: - data = [] - else: - data = [] - - # Append new feedback - data.append({"ss": ss.fb_k, "query": ss.query}) - - # Write updated content back to the file - with open(file_path, "w") as file: - json.dump(data, file, indent=4) - - class LLMResponseParser: """ Description: Parse the response from the LLM service and update the columns based on the response. @@ -191,15 +166,19 @@ def fetch_rag_response(self, query_type, query): f"{rag_response_path['local']}{query_type.lower()}/{query}", json={"query": query, "type": query_type.lower()}, ).json() + ordered_set = self._order_results() + self.rag_response["initial_response"] = ordered_set + + return self.rag_response + + def _order_results(self): doc_set = set() ordered_set = [] for docid in self.rag_response["initial_response"]: if docid not in doc_set: ordered_set.append(docid) doc_set.add(docid) - self.rag_response["initial_response"] = ordered_set - - return self.rag_response + return ordered_set def parse_and_update_response(self, metadata: pd.DataFrame): """ @@ -211,22 +190,10 @@ def parse_and_update_response(self, metadata: pd.DataFrame): - Metadata is filtered based by the Query parsing LLM first and the rag response second - in case structured_query == true, take results are applying data filters. 
""" - if ( - self.apply_llm_before_rag is None or self.llm_response is None - ) and not config["structured_query"]: + if self.apply_llm_before_rag is None or self.llm_response is None: print("No LLM filter.") # print(self.rag_response, flush=True) - filtered_metadata = metadata[ - metadata["did"].isin(self.rag_response["initial_response"]) - ] - filtered_metadata["did"] = pd.Categorical( - filtered_metadata["did"], - categories=self.rag_response["initial_response"], - ordered=True, - ) - filtered_metadata = filtered_metadata.sort_values("did").reset_index( - drop=True - ) + filtered_metadata = self._no_filter(metadata) # print(filtered_metadata) # if no llm response is required, return the initial response @@ -236,39 +203,14 @@ def parse_and_update_response(self, metadata: pd.DataFrame): self.rag_response is not None and self.llm_response is not None ) and not config["structured_query"]: if not self.apply_llm_before_rag: - print("RAG before LLM filter.") - filtered_metadata = metadata[ - metadata["did"].isin(self.rag_response["initial_response"]) - ] - filtered_metadata["did"] = pd.Categorical( - filtered_metadata["did"], - categories=self.rag_response["initial_response"], - ordered=True, - ) - filtered_metadata = filtered_metadata.sort_values("did").reset_index( - drop=True - ) - llm_parser = LLMResponseParser(self.llm_response) + filtered_metadata, llm_parser = self._rag_before_llm(metadata) if self.query_type.lower() == "dataset": llm_parser.get_attributes_from_response() return llm_parser.update_subset_cols(filtered_metadata) + elif self.apply_llm_before_rag: - print("LLM filter before RAG") - llm_parser = LLMResponseParser(self.llm_response) - llm_parser.get_attributes_from_response() - filtered_metadata = llm_parser.update_subset_cols(metadata) - filtered_metadata = filtered_metadata[ - metadata["did"].isin(self.rag_response["initial_response"]) - ] - filtered_metadata["did"] = pd.Categorical( - filtered_metadata["did"], - categories=self.rag_response["initial_response"], - ordered=True, - ) - filtered_metadata = filtered_metadata.sort_values("did").reset_index( - drop=True - ) + filtered_metadata = self._filter_before_rag(metadata) return filtered_metadata elif ( @@ -280,51 +222,97 @@ def parse_and_update_response(self, metadata: pd.DataFrame): "NumberOfFeatures", "NumberOfInstances", ] - print(self.structured_query_response) # Only for debugging. Comment later. + # print(self.structured_query_response) # Only for debugging. Comment later. if self.structured_query_response[0] is not None and isinstance( self.structured_query_response[1], dict ): # Safely attempt to access the "filter" key in the first element - if ( - self.structured_query_response[0].get("filter", None) - and self.database_filtered - ): - filtered_metadata = metadata[ - metadata["did"].isin(self.database_filtered) - ] - print("Showing database filtered data") - else: - filtered_metadata = metadata[ - metadata["did"].isin(self.rag_response["initial_response"]) - ] - print( - "Showing only rag response as filter is empty or none of the rag data satisfies filter conditions." 
- ) - filtered_metadata["did"] = pd.Categorical( - filtered_metadata["did"], - categories=self.rag_response["initial_response"], - ordered=True, - ) - filtered_metadata = filtered_metadata.sort_values("did").reset_index( - drop=True - ) + self._structured_query_on_success(metadata) else: - filtered_metadata = metadata[ - metadata["did"].isin(self.rag_response["initial_response"]) - ] - filtered_metadata["did"] = pd.Categorical( - filtered_metadata["did"], - categories=self.rag_response["initial_response"], - ordered=True, - ) - filtered_metadata = filtered_metadata.sort_values("did").reset_index( - drop=True - ) - print("Showing only rag response") + filtered_metadata = self._structured_query_on_fail(metadata) + # print("Showing only rag response") return filtered_metadata[["did", "name", *col_name]] + def _structured_query_on_fail(self, metadata): + filtered_metadata = metadata[ + metadata["did"].isin(self.rag_response["initial_response"]) + ] + filtered_metadata["did"] = pd.Categorical( + filtered_metadata["did"], + categories=self.rag_response["initial_response"], + ordered=True, + ) + filtered_metadata = filtered_metadata.sort_values("did").reset_index(drop=True) + + return filtered_metadata + + def _structured_query_on_success(self, metadata): + if ( + self.structured_query_response[0].get("filter", None) + and self.database_filtered + ): + filtered_metadata = metadata[metadata["did"].isin(self.database_filtered)] + # print("Showing database filtered data") + else: + filtered_metadata = metadata[ + metadata["did"].isin(self.rag_response["initial_response"]) + ] + # print( + # "Showing only rag response as filter is empty or none of the rag data satisfies filter conditions." + # ) + filtered_metadata["did"] = pd.Categorical( + filtered_metadata["did"], + categories=self.rag_response["initial_response"], + ordered=True, + ) + filtered_metadata = filtered_metadata.sort_values("did").reset_index(drop=True) + + def _filter_before_rag(self, metadata): + print("LLM filter before RAG") + llm_parser = LLMResponseParser(self.llm_response) + llm_parser.get_attributes_from_response() + filtered_metadata = llm_parser.update_subset_cols(metadata) + filtered_metadata = filtered_metadata[ + metadata["did"].isin(self.rag_response["initial_response"]) + ] + filtered_metadata["did"] = pd.Categorical( + filtered_metadata["did"], + categories=self.rag_response["initial_response"], + ordered=True, + ) + filtered_metadata = filtered_metadata.sort_values("did").reset_index(drop=True) + + return filtered_metadata + + def _rag_before_llm(self, metadata): + print("RAG before LLM filter.") + filtered_metadata = metadata[ + metadata["did"].isin(self.rag_response["initial_response"]) + ] + filtered_metadata["did"] = pd.Categorical( + filtered_metadata["did"], + categories=self.rag_response["initial_response"], + ordered=True, + ) + filtered_metadata = filtered_metadata.sort_values("did").reset_index(drop=True) + llm_parser = LLMResponseParser(self.llm_response) + return filtered_metadata, llm_parser + + def _no_filter(self, metadata): + filtered_metadata = metadata[ + metadata["did"].isin(self.rag_response["initial_response"]) + ] + filtered_metadata["did"] = pd.Categorical( + filtered_metadata["did"], + categories=self.rag_response["initial_response"], + ordered=True, + ) + filtered_metadata = filtered_metadata.sort_values("did").reset_index(drop=True) + + return filtered_metadata + class UILoader: """ @@ -335,12 +323,8 @@ def __init__(self, config_path): with open(config_path, "r") as file: # Load config 
self.config = json.load(file) - # self.message_box = message_box - # Paths and display information - # self.chatbot_input_max_chars = 500 - # Load metadata chroma database for structured query self.collec = load_chroma_metadata() @@ -358,74 +342,131 @@ def __init__(self, config_path): self.query_type = "Dataset" self.llm_filter = False self.paths = self.load_paths() + self.info = """ +
+        Machine learning research should be easily accessible and reusable. OpenML is an open platform for sharing datasets, algorithms, and experiments - to learn how to learn better, together.
+ """ + self.logo = "images/favicon.ico" + self.chatbot_display = "How do I do X using OpenML? / Find me a dataset about Y" if "messages" not in st.session_state: st.session_state.messages = [] - # def chat_entry(self): - # """ - # Description: Create the chat input box with a maximum character limit + # container for company description and logo + def generate_logo_header( + self, + ): + + col1, col2 = st.columns([1, 4]) + with col1: + st.image(self.logo, width=100) + with col2: + st.markdown( + self.info, + unsafe_allow_html=True, + ) + + def generate_complete_ui(self): + + self.generate_logo_header() + chat_container = st.container() + with chat_container: + with st.form(key="chat_form"): + user_input = st.text_input( + label="Query", placeholder=self.chatbot_display + ) + query_type = st.selectbox( + "Select Query Type", + ["General Query", "Dataset", "Flow"], + help="Are you looking for a dataset or a flow or just have a general query?", + ) + ai_filter = st.toggle( + "Use AI powered filtering", + value=True, + help="Uses an AI model to identify what columns might be useful to you.", + ) + st.form_submit_button(label="Search") - # """ - # return st.chat_input( - # self.chatbot_display, max_chars=self.chatbot_input_max_chars - # ) - def create_chat_interface(self, user_input, query_type=None): + self.create_chat_interface(user_input=None) + if user_input: + self.create_chat_interface( + user_input, query_type=query_type, ai_filter=ai_filter + ) + + def create_chat_interface(self, user_input, query_type=None, ai_filter=False): """ Description: Create the chat interface and display the chat history and results. Show the user input and the response from the OpenML Agent. """ self.query_type = query_type - # self.llm_filter = llm_filter + self.ai_filter = ai_filter + if user_input is None: with st.chat_message(name="ai"): st.write("OpenML Agent: ", "Hello! How can I help you today?") + st.write( + "Note that results are powered by local LLM models and may not be accurate. Please refer to the official OpenML website for accurate information." 
+ ) # Handle user input if user_input: - st.session_state.messages.append({"role": "user", "content": user_input}) - with st.spinner("Waiting for results..."): - results = self.process_query_chat(user_input) + self._handle_user_input(user_input, query_type) - if not self.query_type == "General Query": - st.session_state.messages.append( + def _handle_user_input(self, user_input, query_type): + st.session_state.messages.append({"role": "user", "content": user_input}) + with st.spinner("Waiting for results..."): + results = self.process_query_chat(user_input) + + if not self.query_type == "General Query": + st.session_state.messages.append( {"role": "OpenML Agent", "content": results} ) - else: - with st.spinner("Fetching results..."): - with requests.get(results, stream=True) as r: - resp_contain = st.empty() - streamed_response = "" - for chunk in r.iter_content(chunk_size=1024): - if chunk: - streamed_response += chunk.decode("utf-8") - resp_contain.markdown(streamed_response) - resp_contain.empty() - st.session_state.messages.append( - {"role": "OpenML Agent", "content": streamed_response} - ) - + else: + self._stream_results(results) + # reverse messages to show the latest message at the top - reversed_messages = [] - for index in range(0, len(st.session_state.messages),2): - reversed_messages.insert(0, st.session_state.messages[index]) - reversed_messages.insert(1, st.session_state.messages[index + 1]) + reversed_messages = self._reverse_session_history() # Display chat history - for message in reversed_messages: - if query_type == "General Query": - pass - if message["role"] == "user": - with st.chat_message(name="user"): - self.display_results(message["content"], "user") - else: - with st.chat_message(name="ai"): - self.display_results(message["content"], "ai") - self.create_download_button() - + self._display_chat_history(query_type, reversed_messages) + self.create_download_button() + + def _display_chat_history(self, query_type, reversed_messages): + for message in reversed_messages: + if query_type == "General Query": + pass + if message["role"] == "user": + with st.chat_message(name="user"): + self.display_results(message["content"], "user") + else: + with st.chat_message(name="ai"): + self.display_results(message["content"], "ai") + + def _reverse_session_history(self): + reversed_messages = [] + for index in range(0, len(st.session_state.messages), 2): + reversed_messages.insert(0, st.session_state.messages[index]) + reversed_messages.insert(1, st.session_state.messages[index + 1]) + return reversed_messages + + def _stream_results(self, results): + with st.spinner("Fetching results..."): + with requests.get(results, stream=True) as r: + resp_contain = st.empty() + streamed_response = "" + for chunk in r.iter_content(chunk_size=1024): + if chunk: + streamed_response += chunk.decode("utf-8") + resp_contain.markdown(streamed_response) + resp_contain.empty() + st.session_state.messages.append( + {"role": "OpenML Agent", "content": streamed_response} + ) + @st.experimental_fragment() def create_download_button(self): - data = "\n".join([str(message["content"]) for message in st.session_state.messages]) + data = "\n".join( + [str(message["content"]) for message in st.session_state.messages] + ) st.download_button( label="Download chat history", data=data, @@ -440,10 +481,8 @@ def display_results(self, initial_response, role): try: st.dataframe(initial_response) - # self.message_box.chat_message(role).write(st.dataframe(initial_response)) except: st.write(initial_response) - # 
self.message_box.chat_message(role).write(initial_response) # Function to handle query processing def process_query_chat(self, query): @@ -457,63 +496,48 @@ def process_query_chat(self, query): ) if self.query_type == "Dataset" or self.query_type == "Flow": - if config["structured_query"]: - # get structured query - response_parser.fetch_structured_query(self.query_type, query) - try: - # get rag response - # using original query instead of extracted topics. - response_parser.fetch_rag_response( - self.query_type, - response_parser.structured_query_response[0]["query"], - ) - - if response_parser.structured_query_response: - st.write( - "Detected Filter(s): ", - json.dumps( - response_parser.structured_query_response[0].get( - "filter", None - ) - ), - ) - else: - st.write("Detected Filter(s): ", None) - # st.write("Detected Topics: ", response_parser.structured_query_response[0].get("query", None)) - if response_parser.structured_query_response[1].get("filter"): - - with st.spinner("Applying LLM Detected Filter(s)..."): - response_parser.database_filter( - response_parser.structured_query_response[1]["filter"], - collec, - ) - except: - # fallback to RAG response - response_parser.fetch_rag_response(self.query_type, query) - else: - # get rag response + if not self.ai_filter: response_parser.fetch_rag_response(self.query_type, query) - if self.llm_filter: - # get llm response - response_parser.fetch_llm_response(query) + return response_parser.parse_and_update_response(self.data_metadata) + else: + # get structured query + self._display_structured_query_results(query, response_parser) results = response_parser.parse_and_update_response(self.data_metadata) return results + elif self.query_type == "General Query": - # response_parser.fetch_documentation_query(query) - # return response_parser.documentation_response - documentation_response_path = ( - self.paths["documentation_query"]["local"] + query + # Return documentation response path + return self.paths["documentation_query"]["local"] + query + + def _display_structured_query_results(self, query, response_parser): + response_parser.fetch_structured_query(self.query_type, query) + try: + # get rag response + # using original query instead of extracted topics. 
+ response_parser.fetch_rag_response( + self.query_type, + response_parser.structured_query_response[0]["query"], ) - # with requests.get(documentation_response_path, stream=True) as r: - # resp_contain = st.empty() - # streamed_response = "" - # for chunk in r.iter_content(chunk_size=1024): - # if chunk: - # streamed_response += chunk.decode("utf-8") - # resp_contain.markdown(streamed_response) - # return requests.get(documentation_response_path, stream=True) - return documentation_response_path + + if response_parser.structured_query_response: + st.write( + "Detected Filter(s): ", + json.dumps( + response_parser.structured_query_response[0].get("filter", None) + ), + ) + else: + st.write("Detected Filter(s): ", None) + if response_parser.structured_query_response[1].get("filter"): + with st.spinner("Applying LLM Detected Filter(s)..."): + response_parser.database_filter( + response_parser.structured_query_response[1]["filter"], + collec, + ) + except: + # fallback to RAG response + response_parser.fetch_rag_response(self.query_type, query) def load_paths(self): """ diff --git a/structured_query/chroma_store_utilis.py b/structured_query/chroma_store_utilis.py index 9765153..494b118 100644 --- a/structured_query/chroma_store_utilis.py +++ b/structured_query/chroma_store_utilis.py @@ -1,11 +1,11 @@ -import sqlalchemy -import pandas as pd +import sys + import chromadb +import pandas as pd +import sqlalchemy from langchain_community.vectorstores.chroma import Chroma from tqdm.auto import tqdm -import sys - sys.path.append("../") sys.path.append("../backend/") from backend.modules.utils import load_config_and_device @@ -44,4 +44,4 @@ def load_chroma_metadata(): if collec.get(ids=ids) == []: collec.add(ids=ids, documents=documents, metadatas=metadatas) - return collec \ No newline at end of file + return collec diff --git a/structured_query/llm_service_structured_query.py b/structured_query/llm_service_structured_query.py index f44c40e..bdb3be0 100644 --- a/structured_query/llm_service_structured_query.py +++ b/structured_query/llm_service_structured_query.py @@ -1,10 +1,11 @@ import json + from fastapi import FastAPI, HTTPException -from llm_service_structured_query_utils import create_query_structuring_chain from fastapi.responses import JSONResponse from httpx import ConnectTimeout -from tenacity import retry, retry_if_exception_type, stop_after_attempt from langchain_community.query_constructors.chroma import ChromaTranslator +from llm_service_structured_query_utils import create_query_structuring_chain +from tenacity import retry, retry_if_exception_type, stop_after_attempt document_content_description = "Metadata of datasets for various machine learning applications fetched from OpenML platform." 
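These hunks exercise LangChain's structured-query machinery: the chain parses free text (e.g. the warm-up query "mushroom data with 2 classess" below) into a StructuredQuery, and ChromaTranslator lowers its filter into Chroma's where-clause syntax. A small sketch of the translation step follows, with a hand-built StructuredQuery standing in for the chain's output; the attribute name "NumberOfClasses" comes from the metadata schema used elsewhere in the frontend.

from langchain.chains.query_constructor.ir import (Comparator, Comparison,
                                                   StructuredQuery)
from langchain_community.query_constructors.chroma import ChromaTranslator

# Hand-built stand-in for what the query-structuring chain would return;
# normally the chain produces this from the user's free-text query.
structured = StructuredQuery(
    query="mushroom data",
    filter=Comparison(
        comparator=Comparator.EQ, attribute="NumberOfClasses", value=2
    ),
    limit=None,
)

translated_query, kwargs = ChromaTranslator().visit_structured_query(
    structured_query=structured
)
print(translated_query)  # -> "mushroom data"
print(kwargs)            # -> {"filter": {"NumberOfClasses": {"$eq": 2}}}

The [1] indexing in visit_structured_query(structured_query=response)[1] in these hunks selects that kwargs dict, which the endpoint hands back as filter_condition.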
@@ -19,9 +20,9 @@ try: print("[INFO] Sending first query to structured query llm to avoid cold start.") - + query = "mushroom data with 2 classess" - response = chain.invoke({"query": query}) + response = chain.invoke({"query": query}) obj = ChromaTranslator() filter_condition = obj.visit_structured_query(structured_query=response)[1] print(response, filter_condition) @@ -45,12 +46,11 @@ async def get_structured_query(query: str): print(response) obj = ChromaTranslator() filter_condition = obj.visit_structured_query(structured_query=response)[1] - + except Exception as e: - print(f"An error occurred: ", HTTPException(status_code=500, detail=f"An error occurred: {e}")) - + print( + f"An error occurred: ", + HTTPException(status_code=500, detail=f"An error occurred: {e}"), + ) + return response, filter_condition - - - - \ No newline at end of file diff --git a/structured_query/llm_service_structured_query_utils.py b/structured_query/llm_service_structured_query_utils.py index 88a43f3..d889bd2 100644 --- a/structured_query/llm_service_structured_query_utils.py +++ b/structured_query/llm_service_structured_query_utils.py @@ -1,14 +1,13 @@ import json import sys + +from langchain.chains.query_constructor.base import ( + get_query_constructor_prompt, load_query_constructor_runnable) from langchain_community.chat_models import ChatOllama +from structured_query_examples import examples # from langchain_ollama.llms import OllamaLLM -from langchain.chains.query_constructor.base import ( - get_query_constructor_prompt, - load_query_constructor_runnable, -) -from structured_query_examples import examples sys.path.append("../") @@ -42,4 +41,4 @@ def create_query_structuring_chain( fix_invalid=True, ) - return chain \ No newline at end of file + return chain diff --git a/structured_query/structuring_query.py b/structured_query/structuring_query.py index c667d28..84907cb 100644 --- a/structured_query/structuring_query.py +++ b/structured_query/structuring_query.py @@ -1,9 +1,7 @@ import json from langchain.chains.query_constructor.base import ( - get_query_constructor_prompt, - load_query_constructor_runnable, -) + get_query_constructor_prompt, load_query_constructor_runnable) from structured_query.structured_query_examples import examples @@ -45,4 +43,4 @@ def structuring_query(query: str): structured_query = chain.invoke(query) - return structured_query.query, structured_query.filter \ No newline at end of file + return structured_query.query, structured_query.filter diff --git a/useful_scripts/test_selfquery.ipynb b/useful_scripts/test_selfquery.ipynb index a03c5aa..48ae9e6 100644 --- a/useful_scripts/test_selfquery.ipynb +++ b/useful_scripts/test_selfquery.ipynb @@ -823,9 +823,7 @@ } ], "source": [ - "structured_query = chain.invoke(\n", - " {\"query\": query}\n", - " )" + "structured_query = chain.invoke({\"query\": query})" ] }, { @@ -865,9 +863,7 @@ ], "source": [ "try:\n", - " structured_query = chain.invoke(\n", - " {\"query\": query}\n", - " )\n", + " structured_query = chain.invoke({\"query\": query})\n", " print(structured_query)\n", "except Exception as e:\n", " error = e" @@ -1286,7 +1282,6 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "DEFAULT_SCHEMA = \"\"\"\\\n", "<< Structured Request Schema >>\n", "When responding use a markdown code snippet with a JSON object formatted in the following schema:\n", @@ -1310,7 +1305,7 @@ "metadata": {}, "outputs": [], "source": [ - "schema = DEFAULT_SCHEMA.format(output = \"hellow\")" + "schema = DEFAULT_SCHEMA.format(output=\"hellow\")" ] }, { 
@@ -1386,6 +1381,7 @@ ")\n", "from langchain.chains.query_constructor.base import StructuredQueryOutputParser\n", "\n", + "\n", "def create_custom_prompt_template():\n", " prompt_template = (\n", " \"Given the initial extraction: {initial_result}, \"\n", @@ -1419,7 +1415,7 @@ " )\n", " return PromptTemplate(\n", " input_variables=[\"initial_result\", \"metric\", \"question\"],\n", - " template=prompt_template\n", + " template=prompt_template,\n", " )" ] }, @@ -1429,17 +1425,23 @@ "metadata": {}, "outputs": [], "source": [ - "def create_query_structuring_chain_with_custom_prompt(document_content_description, content_attr, model=\"llama3.1\"):\n", + "def create_query_structuring_chain_with_custom_prompt(\n", + " document_content_description, content_attr, model=\"llama3.1\"\n", + "):\n", " # Filter the attribute info based on content_attr\n", - " filter_attribute_info = tuple(ai for ai in attribute_info if ai[\"name\"] in content_attr)\n", - " \n", + " filter_attribute_info = tuple(\n", + " ai for ai in attribute_info if ai[\"name\"] in content_attr\n", + " )\n", + "\n", " # Create a custom prompt template\n", " custom_prompt = create_custom_prompt_template()\n", - " \n", + "\n", " # Prepare the input for the custom prompt\n", - " attributes = \", \".join([f\"{attr['name']}: {attr['description']}\" for attr in filter_attribute_info])\n", + " attributes = \", \".join(\n", + " [f\"{attr['name']}: {attr['description']}\" for attr in filter_attribute_info]\n", + " )\n", " # examples_formatted = \"\\n\\n\".join([f\"Input: {ex['input']}\\nOutput: {ex['structured_query']}\" for ex in examples])\n", - " \n", + "\n", " # Create a chain with the custom prompt\n", " chain = LLMChain(\n", " llm=ChatOllama(model=model),\n", @@ -1449,15 +1451,17 @@ " allowed_operators=tuple(Operator),\n", " allowed_attributes=[attr[\"name\"] for attr in filter_attribute_info],\n", " fix_invalid=True,\n", - " )\n", + " ),\n", " )\n", - " return chain.run({\n", - " \"initial_result\": structured_query, \n", - " \"metric\": 0.8, \n", - " \"question\": \"find some mushroom dataset with less than 10k size and json format\",\n", - " \"allowed_comparators\": tuple(Comparator),\n", - " \"allowed_operators\": tuple(Operator)\n", - " })" + " return chain.run(\n", + " {\n", + " \"initial_result\": structured_query,\n", + " \"metric\": 0.8,\n", + " \"question\": \"find some mushroom dataset with less than 10k size and json format\",\n", + " \"allowed_comparators\": tuple(Comparator),\n", + " \"allowed_operators\": tuple(Operator),\n", + " }\n", + " )" ] }, { @@ -1466,7 +1470,10 @@ "metadata": {}, "outputs": [], "source": [ - "create_query_structuring_chain_with_custom_prompt(document_content_description=document_content_description, content_attr=attribute_info)" + "create_query_structuring_chain_with_custom_prompt(\n", + " document_content_description=document_content_description,\n", + " content_attr=attribute_info,\n", + ")" ] }, { @@ -1513,7 +1520,14 @@ } ], "source": [ - "chain.invoke({\"query\": \"find some mushroom dataset with less than 10k size and json format\", \"initial_result\": structured_query, \"metric\":0.8, \"question\":\"find some mushroom dataset with less than 10k size and json format\"})" + "chain.invoke(\n", + " {\n", + " \"query\": \"find some mushroom dataset with less than 10k size and json format\",\n", + " \"initial_result\": structured_query,\n", + " \"metric\": 0.8,\n", + " \"question\": \"find some mushroom dataset with less than 10k size and json format\",\n", + " }\n", + ")" ] }, { @@ -1522,17 +1536,23 
@@ "metadata": {}, "outputs": [], "source": [ - "def create_query_structuring_chain_with_custom_prompt(document_content_description, content_attr, model=\"llama3.1\"):\n", + "def create_query_structuring_chain_with_custom_prompt(\n", + " document_content_description, content_attr, model=\"llama3.1\"\n", + "):\n", " # Filter the attribute info based on content_attr\n", - " filter_attribute_info = tuple(ai for ai in attribute_info if ai[\"name\"] in content_attr)\n", - " \n", + " filter_attribute_info = tuple(\n", + " ai for ai in attribute_info if ai[\"name\"] in content_attr\n", + " )\n", + "\n", " # Create a custom prompt template\n", " custom_prompt = create_custom_prompt_template()\n", - " \n", + "\n", " # Prepare the input for the custom prompt\n", - " attributes = \", \".join([f\"{attr['name']}: {attr['description']}\" for attr in filter_attribute_info])\n", + " attributes = \", \".join(\n", + " [f\"{attr['name']}: {attr['description']}\" for attr in filter_attribute_info]\n", + " )\n", " # examples_formatted = \"\\n\\n\".join([f\"Input: {ex['input']}\\nOutput: {ex['structured_query']}\" for ex in examples])\n", - " \n", + "\n", " # Create a chain with the custom prompt\n", " chain = LLMChain(\n", " llm=ChatOllama(model=model),\n", @@ -1542,39 +1562,42 @@ " allowed_operators=tuple(Operator),\n", " allowed_attributes=[attr[\"name\"] for attr in filter_attribute_info],\n", " fix_invalid=True,\n", - " )\n", + " ),\n", + " )\n", + "\n", + " return chain.run(\n", + " {\n", + " \"document_contents\": document_content_description,\n", + " \"attributes\": attributes,\n", + " \"examples\": None,\n", + " }\n", " )\n", "\n", - " return chain.run({\n", - " \"document_contents\": document_content_description,\n", - " \"attributes\": attributes,\n", - " \"examples\": None\n", - " })\n", "\n", "def validate_and_retry(structured_query_output, model=\"llama3.1\"):\n", " validation_prompt = (\n", " \"Here is a structured query output: {output}\\n\"\n", " \"Please check if it is a valid JSON. 
If it is verbose or incorrect, fix it.\"\n", " )\n", - " \n", + "\n", " chain = LLMChain(\n", " llm=ChatOllama(model=model),\n", - " prompt=PromptTemplate(\n", - " input_variables=[\"output\"],\n", - " template=validation_prompt\n", - " )\n", + " prompt=PromptTemplate(input_variables=[\"output\"], template=validation_prompt),\n", " )\n", - " \n", + "\n", " valid_output = chain.run({\"output\": structured_query_output})\n", " return valid_output\n", "\n", + "\n", "document_description = \"Metadata of machine learning datasets including status, number of classes, number of instances, number of features, and combined information about the dataset.\"\n", "content_attributes = [\"status\", \"number_of_classes\", \"combined_information\"]\n", "\n", - "initial_output = create_query_structuring_chain_with_custom_prompt(document_description, content_attributes)\n", + "initial_output = create_query_structuring_chain_with_custom_prompt(\n", + " document_description, content_attributes\n", + ")\n", "print(initial_output)\n", "final_output = validate_and_retry(initial_output)\n", - "print(final_output)\n" + "print(final_output)" ] } ], diff --git a/useful_scripts/test_strutured_query_pipeline.ipynb b/useful_scripts/test_strutured_query_pipeline.ipynb index fa38a55..06a7444 100644 --- a/useful_scripts/test_strutured_query_pipeline.ipynb +++ b/useful_scripts/test_strutured_query_pipeline.ipynb @@ -360,9 +360,6 @@ " return json.load(file)\n", "\n", "\n", - "\n", - "\n", - "\n", "# Override the load_paths method\n", "def custom_load_paths():\n", " with open(custom_paths_json_path, \"r\") as file:\n",