Skip to content

Commit

Permalink
examples table management
Browse files Browse the repository at this point in the history
  • Loading branch information
jrpereirajr committed Aug 13, 2024
1 parent f423256 commit e04b455
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 87 deletions.
12 changes: 9 additions & 3 deletions python/sqlzilla/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
st.session_state.query_result = None
if 'code_text' not in st.session_state:
st.session_state.code_text = ''
if 'prompt' not in st.session_state:
st.session_state.prompt = ''

def db_connection_str():
user = st.session_state.user
Expand Down Expand Up @@ -81,7 +83,7 @@ def clean_response(response):

database_schema = None
if (st.session_state.namespace and st.session_state.openai_api_key):
sqlzilla = SQLZilla(db_connection_str(), st.session_state.openai_api_key)
sqlzilla = SQLZilla(db_connection_str(), st.session_state.openai_api_key, database_schema)
# Initial prompts for namespace and database schema
try:
query = """
Expand Down Expand Up @@ -126,13 +128,17 @@ def clean_response(response):
st.session_state.query_result = pd.DataFrame(data)
st.dataframe(st.session_state.query_result)

if st.button("Save on library"):
sqlzilla.add_example(st.session_state.prompt, st.session_state.code_text)

with col2:
# Display chat history
for message in st.session_state.chat_history:
st.chat_message(message["role"]).markdown(message["content"])

# React to user input
if prompt := st.chat_input("How can I assist you?"):
st.session_state.prompt = prompt
# Display user message in chat message container
st.chat_message("user").markdown(prompt)
# Add user message to chat history
Expand All @@ -145,8 +151,8 @@ def clean_response(response):
st.session_state.query = response
st.session_state.code_text = response
editor_dict['text'] = response
data = sqlzilla.execute_query(st.session_state.code_text)
st.session_state.query_result = pd.DataFrame(data)
# data = sqlzilla.execute_query(st.session_state.code_text)
# st.session_state.query_result = pd.DataFrame(data)
st.rerun()
# Display assistant response in chat message container
with st.chat_message("assistant"):
Expand Down
135 changes: 51 additions & 84 deletions python/sqlzilla/sqlzilla.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from sqlalchemy import create_engine
import hashlib
import pandas as pd;
import re

from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
Expand All @@ -12,10 +13,10 @@
from langchain_iris import IRISVector

class SQLZilla:
def __init__(self, connection_string, openai_api_key):
def __init__(self, connection_string, openai_api_key, schema_name="SQLUser"):
self.log('criou')
self.openai_api_key = openai_api_key
# self.iris_conn_str = connection_string
self.schema_name = schema_name
self.engine = create_engine(connection_string)
self.conn_wrapper = self.engine.connect()
self.connection = self.conn_wrapper.connection
Expand All @@ -27,70 +28,30 @@ def __init__(self, connection_string, openai_api_key):
self.chain_model = None
self.example_prompt = None
self.create_chain_model()

def create_examples_table(self):
    """Create the ``sqlzilla.examples`` table when it does not exist yet.

    The table holds one row per saved example: the natural-language
    prompt, the SQL query written for it, and the schema name the
    example belongs to. The statement is sent through
    ``self.execute_query``; nothing is returned.
    """
    # The DDL uses IF NOT EXISTS, so an existing table is left untouched.
    ddl = """
CREATE TABLE IF NOT EXISTS sqlzilla.examples (
id INT NOT NULL AUTO_INCREMENT PRIMARY KEY,
prompt VARCHAR(255) NOT NULL,
query VARCHAR(255) NOT NULL,
schema_name VARCHAR(255) NOT NULL
);
"""
    self.execute_query(ddl)

def get_examples(self):
    """Return the saved few-shot examples for the current schema.

    Reads the ``prompt`` and ``query`` columns of ``sqlzilla.examples``
    for rows whose ``schema_name`` equals ``self.schema_name``.

    Returns:
        A list of dicts, one per stored row, each with the keys
        ``"input"`` (the prompt text) and ``"query"`` (the SQL text).
        The list is empty when nothing is stored for the schema.
    """
    # NOTE(review): the "%s" placeholder must match the paramstyle of the
    # DB-API driver behind self.connection -- confirm against the driver.
    sql = "SELECT prompt, query FROM sqlzilla.examples WHERE schema_name = %s"
    rows = self.execute_query(sql, [self.schema_name])
    # execute_query returns None for statements it does not fetch from;
    # treat that the same as "no examples" instead of failing on iteration.
    return [{"input": row[0], "query": row[1]} for row in rows or []]

def add_example(self, prompt, query):
    """Save a prompt/query pair as an example for the current schema.

    Args:
        prompt: The natural-language request.
        query: The SQL statement that answers ``prompt``.

    The row is tagged with ``self.schema_name`` and inserted through
    ``self.execute_query``.
    """
    # The column is named schema_name in the examples table DDL; the
    # statement previously referred to a non-existent "schema" column.
    sql = (
        "INSERT INTO sqlzilla.examples (prompt, query, schema_name) "
        "VALUES (%s, %s, %s)"
    )
    self.execute_query(sql, [prompt, query, self.schema_name])

def __del__(self):
self.log('deletou')
Expand Down Expand Up @@ -207,24 +168,23 @@ def filter_not_in_collection(self, collection_name, docs_array, ids_array):
return list(zip(*filtered)) or ([], [])

def schema_context_management(self, schema):
table_def = self.get_table_definitions_array(schema)
self.table_df = pd.DataFrame(data=table_def, columns=["col_def"])
self.table_df["id"] = self.table_df.index + 1
loader = DataFrameLoader(self.table_df, page_content_column="col_def")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=400, chunk_overlap=20, separator="\n")
self.tables_docs = text_splitter.split_documents(documents)
new_tables_docs, tables_docs_ids = self.filter_not_in_collection(
"sql_tables",
self.tables_docs,
self.get_ids_from_string_array([x.page_content for x in self.tables_docs])
)
self.tables_docs_ids = tables_docs_ids
if self.tables_vector_store is None:
table_def = self.get_table_definitions_array(schema)
self.table_df = pd.DataFrame(data=table_def, columns=["col_def"])
self.table_df["id"] = self.table_df.index + 1
loader = DataFrameLoader(self.table_df, page_content_column="col_def")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=400, chunk_overlap=20, separator="\n")
self.tables_docs = text_splitter.split_documents(documents)
new_tables_docs, tables_docs_ids = self.filter_not_in_collection(
"sql_tables",
self.tables_docs,
self.get_ids_from_string_array([x.page_content for x in self.tables_docs])
)
self.tables_docs_ids = tables_docs_ids
self.tables_vector_store = IRISVector.from_documents(
embedding = OpenAIEmbeddings(openai_api_key=self.openai_api_key),
documents = self.tables_docs,
# connection_string= self.iris_conn_str,
connection=self.conn_wrapper,
collection_name="sql_tables",
ids=self.tables_docs_ids
Expand All @@ -243,7 +203,6 @@ def schema_context_management(self, schema):
IRISVector,
k=5,
input_keys=["input"],
# connection_string=self.iris_conn_str,
connection=self.conn_wrapper,
collection_name="sql_samples",
ids=sql_samples_ids
Expand Down Expand Up @@ -306,14 +265,22 @@ def prompt(self, input):
"top_k": self.context["top_k"],
"table_info": self.context["table_info"],
"examples_value": self.context["examples_value"],
"input": input
"input": self.context["input"]
})
return response

def execute_query(self, query, params=None):
    """Run a SQL statement on ``self.connection`` and return its rows, if any.

    Args:
        query: The SQL text to execute.
        params: Optional sequence of values bound to the statement's
            placeholders. When omitted, the statement is executed
            without a parameter argument.

    Returns:
        The result of ``cursor.fetchall()`` when the statement is a
        SELECT; ``None`` otherwise. INSERT, UPDATE and DELETE statements
        are committed before returning.
    """
    cursor = self.connection.cursor()
    try:
        if params is None:
            cursor.execute(query)
        else:
            cursor.execute(query, params)

        # Classify by the FIRST DML keyword in the text. Searching for
        # "SELECT" anywhere would misread "INSERT ... SELECT ..." (or an
        # UPDATE/DELETE with a subquery) as a read and skip the commit.
        first_keyword = re.search(
            r"\b(SELECT|INSERT|UPDATE|DELETE)\b", query, re.IGNORECASE
        )
        if first_keyword is None:
            # Other statements (e.g. DDL) are neither fetched nor committed.
            return None
        if first_keyword.group(1).upper() == "SELECT":
            return cursor.fetchall()
        self.connection.commit()
        return None
    finally:
        # Release the cursor whether the statement succeeded or raised.
        cursor.close()

0 comments on commit e04b455

Please sign in to comment.