Skip to content

Commit

Permalink
🚧
Browse files Browse the repository at this point in the history
  • Loading branch information
henryhamon committed Jul 29, 2024
1 parent 90204ba commit 19367d8
Showing 1 changed file with 143 additions and 1 deletion.
144 changes: 143 additions & 1 deletion python/sqlzilla/sqlzilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,71 @@ class SQLZilla:
def __init__(self, engine, cnx):
self.engine = engine
self.cnx = cnx
self.context = {}
self.context["top_k"] = 3
self.examples = [
{
"input": "List all aircrafts.",
"query": "SELECT * FROM Aviation.Aircraft"
},
{
"input": "Find all incidents for the aircraft with ID 'N12345'.",
"query": "SELECT * FROM Aviation.Event WHERE EventId IN (SELECT EventId FROM Aviation.Aircraft WHERE ID = 'N12345')"
},
{
"input": "List all incidents in the 'Commercial' operation type.",
"query": "SELECT * FROM Aviation.Event WHERE EventId IN (SELECT EventId FROM Aviation.Aircraft WHERE OperationType = 'Commercial')"
},
{
"input": "Find the total number of incidents.",
"query": "SELECT COUNT(*) FROM Aviation.Event"
},
{
"input": "List all incidents that occurred in 'Canada'.",
"query": "SELECT * FROM Aviation.Event WHERE LocationCountry = 'Canada'"
},
{
"input": "How many incidents are associated with the aircraft with AircraftKey 5?",
"query": "SELECT COUNT(*) FROM Aviation.Aircraft WHERE AircraftKey = 5"
},
{
"input": "Find the total number of distinct aircrafts involved in incidents.",
"query": "SELECT COUNT(DISTINCT AircraftKey) FROM Aviation.Aircraft"
},
{
"input": "List all incidents that occurred after 5 PM.",
"query": "SELECT * FROM Aviation.Event WHERE EventTime > 1700"
},
{
"input": "Who are the top 5 operators by the number of incidents?",
"query": "SELECT TOP 5 OperatorName, COUNT(*) AS IncidentCount FROM Aviation.Aircraft GROUP BY OperatorName ORDER BY IncidentCount DESC"
},
{
"input": "Which incidents occurred in the year 2020?",
"query": "SELECT * FROM Aviation.Event WHERE YEAR(EventDate) = '2020'"
},
{
"input": "What was the month with most events in the year 2020?",
"query": "SELECT TOP 1 MONTH(EventDate) EventMonth, COUNT(*) EventCount FROM Aviation.Event WHERE YEAR(EventDate) = '2020' GROUP BY MONTH(EventDate) ORDER BY EventCount DESC"
},
{
"input": "How many crew members were involved in incidents?",
"query": "SELECT COUNT(*) FROM Aviation.Crew"
},
{
"input": "List all incidents with detailed aircraft information for incidents that occurred in the year 2012.",
"query": "SELECT e.EventId, e.EventDate, a.AircraftManufacturer, a.AircraftModel, a.AircraftCategory FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE Year(e.EventDate) = 2012"
},
{
"input": "Find all incidents where there were more than 5 injuries and include the aircraft manufacturer and model.",
"query": "SELECT e.EventId, e.InjuriesTotal, a.AircraftManufacturer, a.AircraftModel FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE e.InjuriesTotal > 5"
},
{
"input": "List all crew members involved in incidents with serious injuries, along with the incident date and location.",
"query": "SELECT c.CrewNumber, c.Age, c.Sex, e.EventDate, e.LocationCity, e.LocationState FROM Aviation.Crew c JOIN Aviation.Event e ON c.EventId = e.EventId WHERE c.Injury = 'Serious'"
},
]


def get_table_definitions_array(self, schema, table=None):
cursor = self.cnx.cursor()
Expand Down Expand Up @@ -119,4 +184,81 @@ def exists_in_db(self, collection_name, id):

def filter_not_in_collection(self, collection_name, docs_array, ids_array):
filtered = [x for x in zip(docs_array, ids_array) if not self.exists_in_db(collection_name, x[1])]
return list(zip(*filtered)) or ([], [])
return list(zip(*filtered)) or ([], [])

def schema_context_management(self, schema):
table_def = self.get_table_definitions_array(schema)
self.table_df = pd.DataFrame(data=table_def, columns=["col_def"])
self.table_df["id"] = self.table_df.index + 1
loader = DataFrameLoader(self.table_df, page_content_column="col_def")
documents = loader.load()
text_splitter = CharacterTextSplitter(chunk_size=400, chunk_overlap=20, separator="\n")
self.tables_docs = text_splitter.split_documents(documents)
new_tables_docs, tables_docs_ids = self.filter_not_in_collection(
"sql_tables",
self.tables_docs,
self.get_ids_from_string_array([x.page_content for x in tables_docs])
)
self.tables_docs_ids = tables_docs_ids


def prompt(self, input):
db = IRISVector.from_documents(
embedding = OpenAIEmbeddings(),
documents = self.tables_docs,
connection_string=iris_conn_str,
collection_name="sql_tables",
ids=self.tables_docs_ids
)
relevant_tables_docs = db.similarity_search(self.context["input"])
relevant_tables_docs_indices = [x.metadata["id"] for x in relevant_tables_docs]
indices = self.table_df["id"].isin(relevant_tables_docs_indices)
relevant_tables_array = [x for x in self.table_df[indices]["col_def"]]
self.context["table_info"] = "\n\n".join(relevant_tables_array)
new_sql_samples, sql_samples_ids = self.ilter_not_in_collection(
"sql_samples",
self.examples,
self.get_ids_from_string_array([x['input'] for x in self.examples])
)
iris_sql_template = """
You are an InterSystems IRIS expert. Given an input question, first create a syntactically correct InterSystems IRIS query to run and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the TOP clause as per InterSystems IRIS. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in single quotes ('') to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CAST(CURRENT_DATE as date) function to get the current date, if the question involves "today".
"""
tables_prompt_template = """
Only use the following tables:
{table_info}
"""
prompt_sql_few_shots_template = """
Below are a number of examples of questions and their corresponding SQL queries.
{examples_value}
"""
example_prompt_template = "User input: {input}\nSQL query: {query}"
example_prompt = PromptTemplate.from_template(example_prompt_template)
user_prompt = "\n"+example_prompt.invoke({"input": "{input}", "query": ""}).to_string()
prompt = (
ChatPromptTemplate.from_messages([("system", iris_sql_template)])
+ ChatPromptTemplate.from_messages([("system", tables_prompt_template)])
+ ChatPromptTemplate.from_messages([("system", prompt_sql_few_shots_template)])
+ ChatPromptTemplate.from_messages([("human", user_prompt)])
)
prompt_value = prompt.invoke({
"top_k": self.context["top_k"],
"table_info": self.context["table_info"],
"examples_value": self.context["examples_value"],
"input": input
})

model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
output_parser = StrOutputParser()
chain_model = prompt | model | output_parser
response = chain_model.invoke({
"top_k": self.context["top_k"],
"table_info": self.context["table_info"],
"examples_value": self.context["examples_value"],
"input": input
})
return response

0 comments on commit 19367d8

Please sign in to comment.