From e04b4550ac58a2a115b3351daba628ed016a4f10 Mon Sep 17 00:00:00 2001 From: jrpereirajr Date: Tue, 13 Aug 2024 14:44:52 -0300 Subject: [PATCH] examples table management --- python/sqlzilla/app.py | 12 +++- python/sqlzilla/sqlzilla.py | 135 ++++++++++++++---------------------- 2 files changed, 60 insertions(+), 87 deletions(-) diff --git a/python/sqlzilla/app.py b/python/sqlzilla/app.py index ef52f9d..95abb84 100644 --- a/python/sqlzilla/app.py +++ b/python/sqlzilla/app.py @@ -30,6 +30,8 @@ st.session_state.query_result = None if 'code_text' not in st.session_state: st.session_state.code_text = '' +if 'prompt' not in st.session_state: + st.session_state.prompt = '' def db_connection_str(): user = st.session_state.user @@ -81,7 +83,7 @@ def clean_response(response): database_schema = None if (st.session_state.namespace and st.session_state.openai_api_key): - sqlzilla = SQLZilla(db_connection_str(), st.session_state.openai_api_key) + sqlzilla = SQLZilla(db_connection_str(), st.session_state.openai_api_key, database_schema) # Initial prompts for namespace and database schema try: query = """ @@ -126,6 +128,9 @@ def clean_response(response): st.session_state.query_result = pd.DataFrame(data) st.dataframe(st.session_state.query_result) + if st.button("Save on library"): + sqlzilla.add_example(st.session_state.prompt, st.session_state.code_text) + with col2: # Display chat history for message in st.session_state.chat_history: @@ -133,6 +138,7 @@ def clean_response(response): # React to user input if prompt := st.chat_input("How can I assist you?"): + st.session_state.prompt = prompt # Display user message in chat message container st.chat_message("user").markdown(prompt) # Add user message to chat history @@ -145,8 +151,8 @@ def clean_response(response): st.session_state.query = response st.session_state.code_text = response editor_dict['text'] = response - data = sqlzilla.execute_query(st.session_state.code_text) - st.session_state.query_result = pd.DataFrame(data) + # data = sqlzilla.execute_query(st.session_state.code_text) + # st.session_state.query_result = pd.DataFrame(data) st.rerun() # Display assistant response in chat message container with st.chat_message("assistant"): diff --git a/python/sqlzilla/sqlzilla.py b/python/sqlzilla/sqlzilla.py index eafb4c0..bb569bb 100644 --- a/python/sqlzilla/sqlzilla.py +++ b/python/sqlzilla/sqlzilla.py @@ -1,6 +1,7 @@ from sqlalchemy import create_engine import hashlib import pandas as pd; +import re from langchain_core.prompts import PromptTemplate, ChatPromptTemplate from langchain_core.example_selectors import SemanticSimilarityExampleSelector @@ -12,10 +13,10 @@ from langchain_iris import IRISVector class SQLZilla: - def __init__(self, connection_string, openai_api_key): + def __init__(self, connection_string, openai_api_key, schema_name="SQLUser"): self.log('criou') self.openai_api_key = openai_api_key - # self.iris_conn_str = connection_string + self.schema_name = schema_name self.engine = create_engine(connection_string) self.conn_wrapper = self.engine.connect() self.connection = self.conn_wrapper.connection @@ -27,70 +28,30 @@ def __init__(self, connection_string, openai_api_key): self.chain_model = None self.example_prompt = None self.create_chain_model() + + def create_examples_table(self): + sql = """ + CREATE TABLE IF NOT EXISTS sqlzilla.examples ( + id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, + prompt VARCHAR(255) NOT NULL, + query VARCHAR(255) NOT NULL, + schema_name VARCHAR(255) NOT NULL + ); + """ + self.execute_query(sql) def get_examples(self): - return [ - { - "input": "List all aircrafts.", - "query": "SELECT * FROM Aviation.Aircraft" - }, - { - "input": "Find all incidents for the aircraft with ID 'N12345'.", - "query": "SELECT * FROM Aviation.Event WHERE EventId IN (SELECT EventId FROM Aviation.Aircraft WHERE ID = 'N12345')" - }, - { - "input": "List all incidents in the 'Commercial' operation type.", - "query": "SELECT * FROM Aviation.Event WHERE EventId IN (SELECT EventId FROM Aviation.Aircraft WHERE OperationType = 'Commercial')" - }, - { - "input": "Find the total number of incidents.", - "query": "SELECT COUNT(*) FROM Aviation.Event" - }, - { - "input": "List all incidents that occurred in 'Canada'.", - "query": "SELECT * FROM Aviation.Event WHERE LocationCountry = 'Canada'" - }, - { - "input": "How many incidents are associated with the aircraft with AircraftKey 5?", - "query": "SELECT COUNT(*) FROM Aviation.Aircraft WHERE AircraftKey = 5" - }, - { - "input": "Find the total number of distinct aircrafts involved in incidents.", - "query": "SELECT COUNT(DISTINCT AircraftKey) FROM Aviation.Aircraft" - }, - { - "input": "List all incidents that occurred after 5 PM.", - "query": "SELECT * FROM Aviation.Event WHERE EventTime > 1700" - }, - { - "input": "Who are the top 5 operators by the number of incidents?", - "query": "SELECT TOP 5 OperatorName, COUNT(*) AS IncidentCount FROM Aviation.Aircraft GROUP BY OperatorName ORDER BY IncidentCount DESC" - }, - { - "input": "Which incidents occurred in the year 2020?", - "query": "SELECT * FROM Aviation.Event WHERE YEAR(EventDate) = '2020'" - }, - { - "input": "What was the month with most events in the year 2020?", - "query": "SELECT TOP 1 MONTH(EventDate) EventMonth, COUNT(*) EventCount FROM Aviation.Event WHERE YEAR(EventDate) = '2020' GROUP BY MONTH(EventDate) ORDER BY EventCount DESC" - }, - { - "input": "How many crew members were involved in incidents?", - "query": "SELECT COUNT(*) FROM Aviation.Crew" - }, - { - "input": "List all incidents with detailed aircraft information for incidents that occurred in the year 2012.", - "query": "SELECT e.EventId, e.EventDate, a.AircraftManufacturer, a.AircraftModel, a.AircraftCategory FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE Year(e.EventDate) = 2012" - }, - { - "input": "Find all incidents where there were more than 5 injuries and include the aircraft manufacturer and model.", - "query": "SELECT e.EventId, e.InjuriesTotal, a.AircraftManufacturer, a.AircraftModel FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE e.InjuriesTotal > 5" - }, - { - "input": "List all crew members involved in incidents with serious injuries, along with the incident date and location.", - "query": "SELECT c.CrewNumber, c.Age, c.Sex, e.EventDate, e.LocationCity, e.LocationState FROM Aviation.Crew c JOIN Aviation.Event e ON c.EventId = e.EventId WHERE c.Injury = 'Serious'" - }, - ] + sql = "SELECT prompt, query FROM sqlzilla.examples WHERE schema_name = %s" + rows = self.execute_query(sql, [self.schema_name]) + examples = [{ + "input": row[0], + "query": row[1] + } for row in rows] + return examples + + def add_example(self, prompt, query): + sql = "INSERT INTO sqlzilla.examples (prompt, query, schema) VALUES (%s, %s, %s)" + self.execute_query(sql, [prompt, query, self.schema_name]) def __del__(self): self.log('deletou') @@ -207,24 +168,23 @@ def filter_not_in_collection(self, collection_name, docs_array, ids_array): return list(zip(*filtered)) or ([], []) def schema_context_management(self, schema): - table_def = self.get_table_definitions_array(schema) - self.table_df = pd.DataFrame(data=table_def, columns=["col_def"]) - self.table_df["id"] = self.table_df.index + 1 - loader = DataFrameLoader(self.table_df, page_content_column="col_def") - documents = loader.load() - text_splitter = CharacterTextSplitter(chunk_size=400, chunk_overlap=20, separator="\n") - self.tables_docs = text_splitter.split_documents(documents) - new_tables_docs, tables_docs_ids = self.filter_not_in_collection( - "sql_tables", - self.tables_docs, - self.get_ids_from_string_array([x.page_content for x in self.tables_docs]) - ) - self.tables_docs_ids = tables_docs_ids if self.tables_vector_store is None: + table_def = self.get_table_definitions_array(schema) + self.table_df = pd.DataFrame(data=table_def, columns=["col_def"]) + self.table_df["id"] = self.table_df.index + 1 + loader = DataFrameLoader(self.table_df, page_content_column="col_def") + documents = loader.load() + text_splitter = CharacterTextSplitter(chunk_size=400, chunk_overlap=20, separator="\n") + self.tables_docs = text_splitter.split_documents(documents) + new_tables_docs, tables_docs_ids = self.filter_not_in_collection( + "sql_tables", + self.tables_docs, + self.get_ids_from_string_array([x.page_content for x in self.tables_docs]) + ) + self.tables_docs_ids = tables_docs_ids self.tables_vector_store = IRISVector.from_documents( embedding = OpenAIEmbeddings(openai_api_key=self.openai_api_key), documents = self.tables_docs, - # connection_string= self.iris_conn_str, connection=self.conn_wrapper, collection_name="sql_tables", ids=self.tables_docs_ids @@ -243,7 +203,6 @@ def schema_context_management(self, schema): IRISVector, k=5, input_keys=["input"], - # connection_string=self.iris_conn_str, connection=self.conn_wrapper, collection_name="sql_samples", ids=sql_samples_ids @@ -306,14 +265,22 @@ def prompt(self, input): "top_k": self.context["top_k"], "table_info": self.context["table_info"], "examples_value": self.context["examples_value"], - "input": input + "input": self.context["input"] }) return response - def execute_query(self, query): + def execute_query(self, query, params=None): cursor = self.connection.cursor() # Execute the query - cursor.execute(query) + cursor.execute(query, params) - # Fetch the results - return cursor.fetchall() + if re.search(r"\s*SELECT\s+", query, re.IGNORECASE): + # Fetch the results + return cursor.fetchall() + elif re.search(r"\s*INSERT\s+", query, re.IGNORECASE): + self.connection.commit() + elif re.search(r"\s*UPDATE\s+", query, re.IGNORECASE): + self.connection.commit() + elif re.search(r"\s*DELETE\s+", query, re.IGNORECASE): + self.connection.commit() + return None