diff --git a/backend/__init__.py b/backend/__init__.py
index 4a02c7d..5c665c9 100644
--- a/backend/__init__.py
+++ b/backend/__init__.py
@@ -1,4 +1,4 @@
 __version__ = "0.1.0"
 from .modules.rag_llm import *
-from .modules.vector_store_utils import *
 from .modules.utils import *
+from .modules.vector_store_utils import *
diff --git a/backend/modules/metadata_utils.py b/backend/modules/metadata_utils.py
index 00860d8..e9a9ff0 100644
--- a/backend/modules/metadata_utils.py
+++ b/backend/modules/metadata_utils.py
@@ -2,7 +2,6 @@
 
 import os
 import pickle
-
 # from pqdm.processes import pqdm
 from typing import Sequence, Tuple, Union
 
diff --git a/backend/modules/results_gen.py b/backend/modules/results_gen.py
index 9618be8..52ed937 100644
--- a/backend/modules/results_gen.py
+++ b/backend/modules/results_gen.py
@@ -9,9 +9,8 @@
 import pandas as pd
 from flashrank import Ranker, RerankRequest
 from langchain.chains.retrieval_qa.base import RetrievalQA
-from langchain_community.document_transformers.long_context_reorder import (
-    LongContextReorder,
-)
+from langchain_community.document_transformers.long_context_reorder import \
+    LongContextReorder
 from langchain_core.documents import BaseDocumentTransformer, Document
 from tqdm import tqdm
 
diff --git a/docs/Rag Pipeline/Developer Tutorials/change_model.py b/docs/Rag Pipeline/Developer Tutorials/change_model.py
index 01ba86f..fd8cc45 100644
--- a/docs/Rag Pipeline/Developer Tutorials/change_model.py
+++ b/docs/Rag Pipeline/Developer Tutorials/change_model.py
@@ -16,13 +16,15 @@
 # - How would you use a different embedding and llm model?
 
 from __future__ import annotations
-from langchain_community.cache import SQLiteCache
+
 import os
 import sys
+
 import chromadb
+from langchain_community.cache import SQLiteCache
 
-from backend.modules.utils import load_config_and_device
 from backend.modules.rag_llm import QASetup
+from backend.modules.utils import load_config_and_device
 
 # ## Initial config
diff --git a/documentation_bot/documentation_query.py b/documentation_bot/documentation_query.py
index c45aae2..b3ede0e 100644
--- a/documentation_bot/documentation_query.py
+++ b/documentation_bot/documentation_query.py
@@ -1,12 +1,13 @@
-from utils import ChromaStore, Crawler
 import os
+import uuid
+
 from fastapi import FastAPI
-from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse, StreamingResponse
 from httpx import ConnectTimeout
 from tenacity import retry, retry_if_exception_type, stop_after_attempt
-import uuid
+from utils import ChromaStore, Crawler
 
-#TODO : make this into a separate thing using config
+# TODO : make this into a separate thing using config
 recrawl_websites = False
 
 crawled_files_data_path = "../data/crawler/crawled_data.csv"
@@ -41,6 +42,17 @@
 app = FastAPI()
 session_id = str(uuid.uuid4())
 
+
+def stream_response(response):
+    for line in response:
+        try:
+            yield str(line["answer"])
+        except GeneratorExit:
+            break
+        except Exception:
+            yield ""
+
+
 @app.get("/documentationquery/{query}", response_class=JSONResponse)
 @retry(stop=stop_after_attempt(3), retry=retry_if_exception_type(ConnectTimeout))
 async def get_documentation_query(query: str):
@@ -49,4 +61,5 @@
     chroma_store.setup_inference(session_id)
     response = chroma_store.openml_page_search(input=query)
 
-    return JSONResponse(content=response)
+    # return JSONResponse(content=response)
+    return StreamingResponse(stream_response(response), media_type="text/event-stream")
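Note: `StreamingResponse` consumes a generator, so wrapping the chain's chunk iterator in `stream_response` is what lets tokens reach the client as they are produced instead of after the full answer is ready. A minimal, self-contained sketch of the same pattern, where the hypothetical `fake_chain_stream` stands in for the real LangChain chain:

```python
# Sketch only: run with `uvicorn sketch:app`. fake_chain_stream is a stand-in
# for conversational_rag_chain.stream(...).
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

def fake_chain_stream(query: str):
    # LangChain retrieval chains stream dict chunks; only some carry "answer".
    for word in f"echoing {query}".split():
        yield {"answer": word + " "}

def stream_response(chunks):
    for chunk in chunks:
        try:
            yield str(chunk["answer"])
        except GeneratorExit:
            break  # client disconnected: stop pulling from the chain
        except Exception:
            yield ""  # chunk without an "answer" key (e.g. context-only)

@app.get("/documentationquery/{query}")
async def get_documentation_query(query: str):
    return StreamingResponse(
        stream_response(fake_chain_stream(query)), media_type="text/event-stream"
    )
```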
diff --git a/documentation_bot/utils.py b/documentation_bot/utils.py
index 6a1459b..9931649 100644
--- a/documentation_bot/utils.py
+++ b/documentation_bot/utils.py
@@ -1,24 +1,25 @@
-import requests
-from bs4 import BeautifulSoup
 import csv
 import os
+from urllib.parse import urljoin
+
 import pandas as pd
+import requests
 import torch
-from langchain.text_splitter import RecursiveCharacterTextSplitter
+from bs4 import BeautifulSoup
+from langchain.chains import (create_history_aware_retriever,
+                              create_retrieval_chain)
+from langchain.chains.combine_documents import create_stuff_documents_chain
 from langchain.document_loaders import DataFrameLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.chat_message_histories import ChatMessageHistory
 from langchain_community.embeddings import HuggingFaceBgeEmbeddings
 from langchain_community.vectorstores.chroma import Chroma
-from urllib.parse import urljoin
-from tqdm.auto import tqdm
-
-from langchain_community.chat_message_histories import ChatMessageHistory
 from langchain_core.chat_history import BaseChatMessageHistory
-
-from langchain_ollama import ChatOllama
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
 from langchain_core.runnables.history import RunnableWithMessageHistory
-from langchain.chains import create_history_aware_retriever, create_retrieval_chain
-from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_ollama import ChatOllama
+from tqdm.auto import tqdm
+
 
 def find_device() -> str:
     """
@@ -30,7 +31,10 @@ def find_device() -> str:
             return "mps"
     return "cpu"
 
-store = {} #TODO : change this to a nicer way of storing the chat history
+
+store = {}  # TODO : change this to a nicer way of storing the chat history
+
+
 def get_session_history(session_id: str) -> BaseChatMessageHistory:
     # print("this is the session id", session_id)
     if session_id not in store:
@@ -38,6 +42,7 @@ def get_session_history(session_id: str) -> BaseChatMessageHistory:
         # print("this is the store id", store[session_id])
     return store[session_id]
 
+
 class Crawler:
     """
     Description: This class is used to crawl the OpenML website and gather both code and general information for a bot.
@@ -279,7 +284,7 @@ def setup_inference(self, session_id: str) -> None:
         self.store = {}
         self.session_id = session_id
 
-    def openml_page_search(self, input: str) -> str:
+    def openml_page_search(self, input: str):
 
         vectorstore = Chroma(
             persist_directory=self.chroma_file_path,
@@ -326,10 +331,16 @@ def openml_page_search(self, input: str):
             output_messages_key="answer",
         )
 
-        answer = conversational_rag_chain.invoke(
+        # answer = conversational_rag_chain.invoke(
+        #     {"input": f"{input}"},
+        #     config={
+        #         "configurable": {"session_id": self.session_id}
+        #     },  # constructs a key "abc123" in `store`.
+        # )["answer"]
+        answer = conversational_rag_chain.stream(
             {"input": f"{input}"},
             config={
                 "configurable": {"session_id": self.session_id}
             },  # constructs a key "abc123" in `store`.
-        )["answer"]
+        )
 
        return answer
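Note: the switch from `invoke()` to `stream()` changes the return type: `invoke()` blocks and returns the final dict (hence the old `["answer"]` indexing), while `stream()` returns a generator of partial chunks that the endpoint above forwards as they arrive. A rough sketch of the difference on a plain LCEL chain (assumes a locally pulled Ollama model; the model name here is illustrative):

```python
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import ChatOllama

chain = ChatPromptTemplate.from_template("Answer briefly: {input}") | ChatOllama(
    model="llama3.1"  # assumption: any local Ollama chat model works here
)

# invoke(): blocks until the whole answer is ready.
print(chain.invoke({"input": "What is OpenML?"}).content)

# stream(): yields partial message chunks as the model generates them.
for chunk in chain.stream({"input": "What is OpenML?"}):
    print(chunk.content, end="", flush=True)
```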
- )["answer"] + ) return answer diff --git a/evaluation/experiments.py b/evaluation/experiments.py index cdbe1ac..247e34f 100644 --- a/evaluation/experiments.py +++ b/evaluation/experiments.py @@ -1,6 +1,7 @@ -from training_utils import * from concurrent.futures import ThreadPoolExecutor, as_completed + from tqdm.auto import tqdm +from training_utils import * def exp_0(process_query_elastic_search, eval_path, query_key_dict): diff --git a/evaluation/run_all_training.py b/evaluation/run_all_training.py index 772fbd1..eb5ce94 100644 --- a/evaluation/run_all_training.py +++ b/evaluation/run_all_training.py @@ -5,11 +5,13 @@ from __future__ import annotations import json -from pathlib import Path import os -from backend.modules.utils import load_config_and_device -from training_utils import * +from pathlib import Path + from experiments import * +from training_utils import * + +from backend.modules.utils import load_config_and_device if __name__ == "__main__": # %% diff --git a/evaluation/training_utils.py b/evaluation/training_utils.py index 90c47fe..28aa5f3 100644 --- a/evaluation/training_utils.py +++ b/evaluation/training_utils.py @@ -19,13 +19,12 @@ # change the path to the backend directory sys.path.append(os.path.join(os.path.dirname("."), "../backend/")) import requests +# add modules from ui_utils +from tqdm.auto import tqdm # %% from backend.modules.rag_llm import * from backend.modules.results_gen import * - -# add modules from ui_utils -from tqdm.auto import tqdm from frontend.ui_utils import * diff --git a/frontend/ui.py b/frontend/ui.py index afa08ed..14dec5f 100644 --- a/frontend/ui.py +++ b/frontend/ui.py @@ -4,37 +4,15 @@ from ui_utils import * # Streamlit Chat Interface -logo = "images/favicon.ico" page_title = "OpenML : A worldwide machine learning lab" -info = """ -
diff --git a/frontend/ui.py b/frontend/ui.py
index afa08ed..14dec5f 100644
--- a/frontend/ui.py
+++ b/frontend/ui.py
@@ -4,37 +4,15 @@
 from ui_utils import *
 
 # Streamlit Chat Interface
-logo = "images/favicon.ico"
 page_title = "OpenML : A worldwide machine learning lab"
-info = """
-Machine learning research should be easily accessible and reusable. OpenML is an open platform for sharing datasets, algorithms, and experiments - to learn how to learn better, together.
-Ask me anything about OpenML or search for a dataset ...
-"""
@@ ... @@
+        self.info = """
+Machine learning research should be easily accessible and reusable. OpenML is an open platform for sharing datasets, algorithms, and experiments - to learn how to learn better, together.
+        """
+        self.logo = "images/favicon.ico"
+        self.chatbot_display = "How do I do X using OpenML? / Find me a dataset about Y"
 
         if "messages" not in st.session_state:
             st.session_state.messages = []
 
-    # def chat_entry(self):
-    #     """
-    #     Description: Create the chat input box with a maximum character limit
+    # container for company description and logo
+    def generate_logo_header(
+        self,
+    ):
+
+        col1, col2 = st.columns([1, 4])
+        with col1:
+            st.image(self.logo, width=100)
+        with col2:
+            st.markdown(
+                self.info,
+                unsafe_allow_html=True,
+            )
 
-    #     """
-    #     return st.chat_input(
-    #         self.chatbot_display, max_chars=self.chatbot_input_max_chars
-    #     )
+    def generate_complete_ui(self):
 
-    def create_chat_interface(self, user_input, query_type=None):
+        self.generate_logo_header()
+        chat_container = st.container()
+        with chat_container:
+            with st.form(key="chat_form"):
+                user_input = st.text_input(
+                    label="Query", placeholder=self.chatbot_display
+                )
+                query_type = st.selectbox(
+                    "Select Query Type",
+                    ["General Query", "Dataset", "Flow"],
+                    help="Are you looking for a dataset or a flow or just have a general query?",
+                )
+                ai_filter = st.toggle(
+                    "Use AI powered filtering",
+                    value=True,
+                    help="Uses an AI model to identify what columns might be useful to you.",
+                )
+                st.form_submit_button(label="Search")
+
+        self.create_chat_interface(user_input=None)
+        if user_input:
+            self.create_chat_interface(
+                user_input, query_type=query_type, ai_filter=ai_filter
+            )
+
+    def create_chat_interface(self, user_input, query_type=None, ai_filter=False):
         """
         Description: Create the chat interface and display the chat history and results.
         Show the user input and the response from the OpenML Agent.
 
         """
         self.query_type = query_type
-        # self.llm_filter = llm_filter
+        self.ai_filter = ai_filter
+
         if user_input is None:
-            with st.chat_message(name = "ai"):
+            with st.chat_message(name="ai"):
                 st.write("OpenML Agent: ", "Hello! How can I help you today?")
-
+                st.write(
+                    "Note that results are powered by local LLM models and may not be accurate. Please refer to the official OpenML website for accurate information."
+                )
+
         # Handle user input
         if user_input:
-            st.session_state.messages.append({"role": "user", "content": user_input})
-            with st.spinner("Waiting for results..."):
-                results = self.process_query_chat(user_input)
+            self._handle_user_input(user_input, query_type)
+
+    def _handle_user_input(self, user_input, query_type):
+        st.session_state.messages.append({"role": "user", "content": user_input})
+        with st.spinner("Waiting for results..."):
+            results = self.process_query_chat(user_input)
+        if not self.query_type == "General Query":
             st.session_state.messages.append(
-                {"role": "OpenML Agent", "content": results}
-            )
+                {"role": "OpenML Agent", "content": results}
+            )
+        else:
+            self._stream_results(results)
+
+        # reverse messages to show the latest message at the top
+        reversed_messages = self._reverse_session_history()
 
-        # Display chat history
-        for message in st.session_state.messages:
+        # Display chat history
+        self._display_chat_history(query_type, reversed_messages)
+        self.create_download_button()
+
+    def _display_chat_history(self, query_type, reversed_messages):
+        for message in reversed_messages:
+            if query_type == "General Query":
+                pass
             if message["role"] == "user":
-                with st.chat_message(name = "user"):
+                with st.chat_message(name="user"):
                     self.display_results(message["content"], "user")
             else:
-                with st.chat_message(name = "ai"):
+                with st.chat_message(name="ai"):
                     self.display_results(message["content"], "ai")
 
-    def display_results(self,initial_response, role):
+    def _reverse_session_history(self):
+        reversed_messages = []
+        for index in range(0, len(st.session_state.messages), 2):
+            reversed_messages.insert(0, st.session_state.messages[index])
+            reversed_messages.insert(1, st.session_state.messages[index + 1])
+        return reversed_messages
+
+    def _stream_results(self, results):
+        with st.spinner("Fetching results..."):
+            with requests.get(results, stream=True) as r:
+                resp_contain = st.empty()
+                streamed_response = ""
+                for chunk in r.iter_content(chunk_size=1024):
+                    if chunk:
+                        streamed_response += chunk.decode("utf-8")
+                        resp_contain.markdown(streamed_response)
+                resp_contain.empty()
+        st.session_state.messages.append(
+            {"role": "OpenML Agent", "content": streamed_response}
+        )
+
+    @st.experimental_fragment()
+    def create_download_button(self):
+        data = "\n".join(
+            [str(message["content"]) for message in st.session_state.messages]
+        )
+        st.download_button(
+            label="Download chat history",
+            data=data,
+            file_name="chat_history.txt",
+        )
+
+    def display_results(self, initial_response, role):
         """
         Description: Display the results in a DataFrame
 
         """
         # st.write("OpenML Agent: ")
-        
+
         try:
             st.dataframe(initial_response)
-            # self.message_box.chat_message(role).write(st.dataframe(initial_response))
         except:
             st.write(initial_response)
-            # self.message_box.chat_message(role).write(initial_response)
 
     # Function to handle query processing
     def process_query_chat(self, query):
@@ -433,40 +496,52 @@ def process_query_chat(self, query):
         )
 
         if self.query_type == "Dataset" or self.query_type == "Flow":
-            if config["structured_query"]:
-                # get structured query
-                response_parser.fetch_structured_query(self.query_type, query)
-                try:
-                    # get rag response
-                    # using original query instead of extracted topics.
-                    response_parser.fetch_rag_response(
-                        self.query_type,
-                        response_parser.structured_query_response[0]["query"],
-                    )
-
-                    if response_parser.structured_query_response:
-                        st.write("Detected Filter(s): ", json.dumps(response_parser.structured_query_response[0].get("filter", None)))
-                    else:
-                        st.write("Detected Filter(s): ", None)
-                    # st.write("Detected Topics: ", response_parser.structured_query_response[0].get("query", None))
-                    if response_parser.structured_query_response[1].get("filter"):
-
-                        with st.spinner("Applying LLM Detected Filter(s)..."):
-                            response_parser.database_filter(
-                                response_parser.structured_query_response[1]["filter"], collec
-                            )
-                except:
-                    # fallback to RAG response
-                    response_parser.fetch_rag_response(self.query_type, query)
-            else:
-                # get rag response
+            if not self.ai_filter:
                 response_parser.fetch_rag_response(self.query_type, query)
-                if self.llm_filter:
-                    # get llm response
-                    response_parser.fetch_llm_response(query)
+                return response_parser.parse_and_update_response(self.data_metadata)
+            else:
+                # get structured query
+                self._display_structured_query_results(query, response_parser)
             results = response_parser.parse_and_update_response(self.data_metadata)
             return results
+
         elif self.query_type == "General Query":
-            response_parser.fetch_documentation_query(query)
-            return response_parser.documentation_response
+            # Return documentation response path
+            return self.paths["documentation_query"]["local"] + query
+
+    def _display_structured_query_results(self, query, response_parser):
+        response_parser.fetch_structured_query(self.query_type, query)
+        try:
+            # get rag response
+            # using original query instead of extracted topics.
+            response_parser.fetch_rag_response(
+                self.query_type,
+                response_parser.structured_query_response[0]["query"],
+            )
+
+            if response_parser.structured_query_response:
+                st.write(
+                    "Detected Filter(s): ",
+                    json.dumps(
+                        response_parser.structured_query_response[0].get("filter", None)
+                    ),
+                )
+            else:
+                st.write("Detected Filter(s): ", None)
+            if response_parser.structured_query_response[1].get("filter"):
+                with st.spinner("Applying LLM Detected Filter(s)..."):
+                    response_parser.database_filter(
+                        response_parser.structured_query_response[1]["filter"],
+                        collec,
+                    )
+        except Exception:
+            # fallback to RAG response
+            response_parser.fetch_rag_response(self.query_type, query)
+
+    def load_paths(self):
+        """
+        Description: Load paths from paths.json
+        """
+        with open("paths.json", "r") as file:
+            return json.load(file)
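Note: `_reverse_session_history` shows the newest exchange first while keeping each question attached to its answer; it assumes the history always holds complete user/agent pairs (an odd-length history would raise `IndexError`). A toy run of the same loop:

```python
messages = [
    {"role": "user", "content": "q1"},
    {"role": "OpenML Agent", "content": "a1"},
    {"role": "user", "content": "q2"},
    {"role": "OpenML Agent", "content": "a2"},
]

reversed_messages = []
for index in range(0, len(messages), 2):
    reversed_messages.insert(0, messages[index])      # user turn to the front
    reversed_messages.insert(1, messages[index + 1])  # its answer right after

# newest pair first, each question still followed by its answer
assert [m["content"] for m in reversed_messages] == ["q2", "a2", "q1", "a1"]
```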
diff --git a/structured_query/chroma_store_utilis.py b/structured_query/chroma_store_utilis.py
index 9765153..494b118 100644
--- a/structured_query/chroma_store_utilis.py
+++ b/structured_query/chroma_store_utilis.py
@@ -1,11 +1,11 @@
-import sqlalchemy
-import pandas as pd
+import sys
+
 import chromadb
+import pandas as pd
+import sqlalchemy
 from langchain_community.vectorstores.chroma import Chroma
 from tqdm.auto import tqdm
 
-import sys
-
 sys.path.append("../")
 sys.path.append("../backend/")
 from backend.modules.utils import load_config_and_device
@@ -44,4 +44,4 @@ def load_chroma_metadata():
     if collec.get(ids=ids) == []:
         collec.add(ids=ids, documents=documents, metadatas=metadatas)
 
-    return collec
\ No newline at end of file
+    return collec
diff --git a/structured_query/llm_service_structured_query.py b/structured_query/llm_service_structured_query.py
index f44c40e..bdb3be0 100644
--- a/structured_query/llm_service_structured_query.py
+++ b/structured_query/llm_service_structured_query.py
@@ -1,10 +1,11 @@
 import json
+
 from fastapi import FastAPI, HTTPException
-from llm_service_structured_query_utils import create_query_structuring_chain
 from fastapi.responses import JSONResponse
 from httpx import ConnectTimeout
-from tenacity import retry, retry_if_exception_type, stop_after_attempt
 from langchain_community.query_constructors.chroma import ChromaTranslator
+from llm_service_structured_query_utils import create_query_structuring_chain
+from tenacity import retry, retry_if_exception_type, stop_after_attempt
 
 document_content_description = "Metadata of datasets for various machine learning applications fetched from OpenML platform."
@@ -19,9 +20,9 @@
 try:
     print("[INFO] Sending first query to structured query llm to avoid cold start.")
-    
+
     query = "mushroom data with 2 classess"
-    response = chain.invoke({"query": query})    
+    response = chain.invoke({"query": query})
     obj = ChromaTranslator()
     filter_condition = obj.visit_structured_query(structured_query=response)[1]
     print(response, filter_condition)
@@ -45,12 +46,11 @@ async def get_structured_query(query: str):
         print(response)
         obj = ChromaTranslator()
         filter_condition = obj.visit_structured_query(structured_query=response)[1]
-        
+
     except Exception as e:
-        print(f"An error occurred: ", HTTPException(status_code=500, detail=f"An error occurred: {e}"))
-        
+        print(
+            "An error occurred: ",
+            HTTPException(status_code=500, detail=f"An error occurred: {e}"),
+        )
+
     return response, filter_condition
-
-
-
-
\ No newline at end of file
diff --git a/structured_query/llm_service_structured_query_utils.py b/structured_query/llm_service_structured_query_utils.py
index 88a43f3..d889bd2 100644
--- a/structured_query/llm_service_structured_query_utils.py
+++ b/structured_query/llm_service_structured_query_utils.py
@@ -1,14 +1,13 @@
 import json
 import sys
+
+from langchain.chains.query_constructor.base import (
+    get_query_constructor_prompt, load_query_constructor_runnable)
 from langchain_community.chat_models import ChatOllama
+from structured_query_examples import examples
 
 # from langchain_ollama.llms import OllamaLLM
-from langchain.chains.query_constructor.base import (
-    get_query_constructor_prompt,
-    load_query_constructor_runnable,
-)
-from structured_query_examples import examples
 
 sys.path.append("../")
@@ -42,4 +41,4 @@ def create_query_structuring_chain(
         fix_invalid=True,
     )
 
-    return chain
\ No newline at end of file
+    return chain
diff --git a/structured_query/structuring_query.py b/structured_query/structuring_query.py
index c667d28..84907cb 100644
--- a/structured_query/structuring_query.py
+++ b/structured_query/structuring_query.py
@@ -1,9 +1,7 @@
 import json
 
 from langchain.chains.query_constructor.base import (
-    get_query_constructor_prompt,
-    load_query_constructor_runnable,
-)
+    get_query_constructor_prompt, load_query_constructor_runnable)
 from structured_query.structured_query_examples import examples
 
@@ -45,4 +43,4 @@ def structuring_query(query: str):
 
     structured_query = chain.invoke(query)
 
-    return structured_query.query, structured_query.filter
\ No newline at end of file
+    return structured_query.query, structured_query.filter
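Note: the four modules above all revolve around LangChain's query-constructor chain, which turns a natural-language request into a `StructuredQuery` whose `.filter` the `ChromaTranslator` converts into Chroma's native `where` syntax. A rough end-to-end sketch under the same assumptions as the service (a local Ollama model; the attribute schema is abbreviated to a single illustrative field):

```python
from langchain.chains.query_constructor.base import load_query_constructor_runnable
from langchain_community.chat_models import ChatOllama
from langchain_community.query_constructors.chroma import ChromaTranslator

chain = load_query_constructor_runnable(
    ChatOllama(model="llama3.1"),  # assumes a locally pulled Ollama model
    "Metadata of datasets fetched from the OpenML platform.",
    [
        {
            "name": "NumberOfClasses",
            "description": "number of target classes",
            "type": "float",
        }
    ],
    fix_invalid=True,
)

response = chain.invoke({"query": "mushroom data with 2 classes"})
print(response.query)  # free-text part, e.g. "mushroom data"
# Translate the filter into Chroma's syntax, as the service does:
filter_condition = ChromaTranslator().visit_structured_query(
    structured_query=response
)[1]
print(filter_condition)  # e.g. {"filter": {"NumberOfClasses": {"$eq": 2.0}}}
```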
diff --git a/useful_scripts/test_selfquery.ipynb b/useful_scripts/test_selfquery.ipynb
index a03c5aa..48ae9e6 100644
--- a/useful_scripts/test_selfquery.ipynb
+++ b/useful_scripts/test_selfquery.ipynb
@@ -823,9 +823,7 @@
     }
    ],
    "source": [
-    "structured_query = chain.invoke(\n",
-    "    {\"query\": query}\n",
-    "    )"
+    "structured_query = chain.invoke({\"query\": query})"
    ]
   },
   {
@@ -865,9 +863,7 @@
    ],
    "source": [
     "try:\n",
-    "    structured_query = chain.invoke(\n",
-    "        {\"query\": query}\n",
-    "    )\n",
+    "    structured_query = chain.invoke({\"query\": query})\n",
     "    print(structured_query)\n",
     "except Exception as e:\n",
     "    error = e"
@@ -1286,7 +1282,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "\n",
     "DEFAULT_SCHEMA = \"\"\"\\\n",
     "<< Structured Request Schema >>\n",
     "When responding use a markdown code snippet with a JSON object formatted in the following schema:\n",
@@ -1310,7 +1305,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "schema = DEFAULT_SCHEMA.format(output = \"hellow\")"
+    "schema = DEFAULT_SCHEMA.format(output=\"hellow\")"
    ]
   },
   {
@@ -1386,6 +1381,7 @@
     ")\n",
     "from langchain.chains.query_constructor.base import StructuredQueryOutputParser\n",
     "\n",
+    "\n",
     "def create_custom_prompt_template():\n",
     "    prompt_template = (\n",
     "        \"Given the initial extraction: {initial_result}, \"\n",
@@ -1419,7 +1415,7 @@
     "    )\n",
     "    return PromptTemplate(\n",
     "        input_variables=[\"initial_result\", \"metric\", \"question\"],\n",
-    "        template=prompt_template\n",
+    "        template=prompt_template,\n",
     "    )"
    ]
   },
@@ -1429,17 +1425,23 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def create_query_structuring_chain_with_custom_prompt(document_content_description, content_attr, model=\"llama3.1\"):\n",
+    "def create_query_structuring_chain_with_custom_prompt(\n",
+    "    document_content_description, content_attr, model=\"llama3.1\"\n",
+    "):\n",
     "    # Filter the attribute info based on content_attr\n",
-    "    filter_attribute_info = tuple(ai for ai in attribute_info if ai[\"name\"] in content_attr)\n",
-    "    \n",
+    "    filter_attribute_info = tuple(\n",
+    "        ai for ai in attribute_info if ai[\"name\"] in content_attr\n",
+    "    )\n",
+    "\n",
     "    # Create a custom prompt template\n",
     "    custom_prompt = create_custom_prompt_template()\n",
-    "    \n",
+    "\n",
     "    # Prepare the input for the custom prompt\n",
-    "    attributes = \", \".join([f\"{attr['name']}: {attr['description']}\" for attr in filter_attribute_info])\n",
+    "    attributes = \", \".join(\n",
+    "        [f\"{attr['name']}: {attr['description']}\" for attr in filter_attribute_info]\n",
+    "    )\n",
     "    # examples_formatted = \"\\n\\n\".join([f\"Input: {ex['input']}\\nOutput: {ex['structured_query']}\" for ex in examples])\n",
-    "    \n",
+    "\n",
     "    # Create a chain with the custom prompt\n",
     "    chain = LLMChain(\n",
     "        llm=ChatOllama(model=model),\n",
@@ -1449,15 +1451,17 @@
     "            allowed_operators=tuple(Operator),\n",
     "            allowed_attributes=[attr[\"name\"] for attr in filter_attribute_info],\n",
     "            fix_invalid=True,\n",
-    "        )\n",
+    "        ),\n",
     "    )\n",
-    "    return chain.run({\n",
-    "        \"initial_result\": structured_query, \n",
-    "        \"metric\": 0.8, \n",
-    "        \"question\": \"find some mushroom dataset with less than 10k size and json format\",\n",
-    "        \"allowed_comparators\": tuple(Comparator),\n",
-    "        \"allowed_operators\": tuple(Operator)\n",
-    "    })"
+    "    return chain.run(\n",
+    "        {\n",
+    "            \"initial_result\": structured_query,\n",
+    "            \"metric\": 0.8,\n",
+    "            \"question\": \"find some mushroom dataset with less than 10k size and json format\",\n",
+    "            \"allowed_comparators\": tuple(Comparator),\n",
+    "            \"allowed_operators\": tuple(Operator),\n",
+    "        }\n",
+    "    )"
    ]
   },
   {
@@ -1466,7 +1470,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "create_query_structuring_chain_with_custom_prompt(document_content_description=document_content_description, content_attr=attribute_info)"
+    "create_query_structuring_chain_with_custom_prompt(\n",
+    "    document_content_description=document_content_description,\n",
+    "    content_attr=attribute_info,\n",
+    ")"
    ]
   },
   {
@@ -1513,7 +1520,14 @@
    }
   ],
   "source": [
-    "chain.invoke({\"query\": \"find some mushroom dataset with less than 10k size and json format\", \"initial_result\": structured_query, \"metric\":0.8, \"question\":\"find some mushroom dataset with less than 10k size and json format\"})"
+    "chain.invoke(\n",
+    "    {\n",
+    "        \"query\": \"find some mushroom dataset with less than 10k size and json format\",\n",
+    "        \"initial_result\": structured_query,\n",
+    "        \"metric\": 0.8,\n",
+    "        \"question\": \"find some mushroom dataset with less than 10k size and json format\",\n",
+    "    }\n",
+    ")"
    ]
   },
   {
@@ -1522,17 +1536,23 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def create_query_structuring_chain_with_custom_prompt(document_content_description, content_attr, model=\"llama3.1\"):\n",
+    "def create_query_structuring_chain_with_custom_prompt(\n",
+    "    document_content_description, content_attr, model=\"llama3.1\"\n",
+    "):\n",
     "    # Filter the attribute info based on content_attr\n",
-    "    filter_attribute_info = tuple(ai for ai in attribute_info if ai[\"name\"] in content_attr)\n",
-    "    \n",
+    "    filter_attribute_info = tuple(\n",
+    "        ai for ai in attribute_info if ai[\"name\"] in content_attr\n",
+    "    )\n",
+    "\n",
     "    # Create a custom prompt template\n",
     "    custom_prompt = create_custom_prompt_template()\n",
-    "    \n",
+    "\n",
     "    # Prepare the input for the custom prompt\n",
-    "    attributes = \", \".join([f\"{attr['name']}: {attr['description']}\" for attr in filter_attribute_info])\n",
+    "    attributes = \", \".join(\n",
+    "        [f\"{attr['name']}: {attr['description']}\" for attr in filter_attribute_info]\n",
+    "    )\n",
     "    # examples_formatted = \"\\n\\n\".join([f\"Input: {ex['input']}\\nOutput: {ex['structured_query']}\" for ex in examples])\n",
-    "    \n",
+    "\n",
     "    # Create a chain with the custom prompt\n",
     "    chain = LLMChain(\n",
     "        llm=ChatOllama(model=model),\n",
@@ -1542,39 +1562,42 @@
     "            allowed_operators=tuple(Operator),\n",
     "            allowed_attributes=[attr[\"name\"] for attr in filter_attribute_info],\n",
     "            fix_invalid=True,\n",
-    "        )\n",
+    "        ),\n",
+    "    )\n",
+    "\n",
+    "    return chain.run(\n",
+    "        {\n",
+    "            \"document_contents\": document_content_description,\n",
+    "            \"attributes\": attributes,\n",
+    "            \"examples\": None,\n",
+    "        }\n",
     "    )\n",
     "\n",
-    "    return chain.run({\n",
-    "        \"document_contents\": document_content_description,\n",
-    "        \"attributes\": attributes,\n",
-    "        \"examples\": None\n",
-    "    })\n",
     "\n",
     "def validate_and_retry(structured_query_output, model=\"llama3.1\"):\n",
     "    validation_prompt = (\n",
     "        \"Here is a structured query output: {output}\\n\"\n",
     "        \"Please check if it is a valid JSON. If it is verbose or incorrect, fix it.\"\n",
     "    )\n",
-    "    \n",
+    "\n",
     "    chain = LLMChain(\n",
     "        llm=ChatOllama(model=model),\n",
-    "        prompt=PromptTemplate(\n",
-    "            input_variables=[\"output\"],\n",
-    "            template=validation_prompt\n",
-    "        )\n",
+    "        prompt=PromptTemplate(input_variables=[\"output\"], template=validation_prompt),\n",
     "    )\n",
-    "    \n",
+    "\n",
     "    valid_output = chain.run({\"output\": structured_query_output})\n",
     "    return valid_output\n",
     "\n",
+    "\n",
     "document_description = \"Metadata of machine learning datasets including status, number of classes, number of instances, number of features, and combined information about the dataset.\"\n",
     "content_attributes = [\"status\", \"number_of_classes\", \"combined_information\"]\n",
     "\n",
-    "initial_output = create_query_structuring_chain_with_custom_prompt(document_description, content_attributes)\n",
+    "initial_output = create_query_structuring_chain_with_custom_prompt(\n",
+    "    document_description, content_attributes\n",
+    ")\n",
     "print(initial_output)\n",
     "final_output = validate_and_retry(initial_output)\n",
-    "print(final_output)\n"
+    "print(final_output)"
    ]
   }
 ],
diff --git a/useful_scripts/test_strutured_query_pipeline.ipynb b/useful_scripts/test_strutured_query_pipeline.ipynb
index fa38a55..06a7444 100644
--- a/useful_scripts/test_strutured_query_pipeline.ipynb
+++ b/useful_scripts/test_strutured_query_pipeline.ipynb
@@ -360,9 +360,6 @@
     "    return json.load(file)\n",
     "\n",
     "\n",
-    "\n",
-    "\n",
-    "\n",
     "# Override the load_paths method\n",
     "def custom_load_paths():\n",
     "    with open(custom_paths_json_path, \"r\") as file:\n",
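Note: the notebook's validate-and-retry idea (feed the model's structured output back to it until the output parses as JSON) does not depend on LangChain; a minimal self-contained version, with a hypothetical `ask_llm` standing in for the `LLMChain` call:

```python
import json

def ask_llm(prompt: str) -> str:
    # hypothetical stand-in for chain.run / a ChatOllama call
    return '{"query": "mushroom", "filter": {"NumberOfClasses": 2}}'

def validate_and_retry(output: str, max_attempts: int = 3) -> dict:
    for _ in range(max_attempts):
        try:
            return json.loads(output)  # valid JSON: we are done
        except json.JSONDecodeError:
            output = ask_llm(
                f"Here is a structured query output: {output}\n"
                "Please check if it is a valid JSON. If it is verbose or incorrect, fix it."
            )
    raise ValueError("no valid JSON after retries")

print(validate_and_retry('{"query": "mushroom"'))  # broken input gets repaired once
```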