diff --git a/.gitignore b/.gitignore
index c29aef2..d907eb2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,3 +3,4 @@ iris-main.log
 .env
 .git
 tmp
+__pycache__
\ No newline at end of file
diff --git a/python/sqlzilla/app copy 2.py b/python/sqlzilla/app copy 2.py
new file mode 100644
index 0000000..7b74fae
--- /dev/null
+++ b/python/sqlzilla/app copy 2.py
@@ -0,0 +1,168 @@
+import streamlit as st
+from code_editor import code_editor
+import requests
+from sqlalchemy import create_engine
+import pandas as pd
+from dotenv import load_dotenv
+import os
+from sqlalchemy import create_engine
+from sqlzilla import SQLZilla
+
+# Load environment variables from .env file
+load_dotenv()
+
+# Initialize session state
+if 'hostname' not in st.session_state:
+    st.session_state.hostname = 'sqlzilla-iris-1'
+if 'user' not in st.session_state:
+    st.session_state.user = '_system'
+if 'pwd' not in st.session_state:
+    st.session_state.pwd = 'SYS'
+if 'port' not in st.session_state:
+    st.session_state.port = '1972'
+if 'namespace' not in st.session_state:
+    st.session_state.namespace = "IRISAPP"
+if 'openai_api_key' not in st.session_state:
+    st.session_state.openai_api_key = os.getenv('OPENAI_API_KEY', '')
+if 'chat_history' not in st.session_state:
+    st.session_state.chat_history = [{"role": "assistant", "content": "I'm SQLZilla, your friendly AI SQL helper.\n Ask me anything, from basic queries to complex optimizations."}]
+if 'query_result' not in st.session_state:
+    st.session_state.query_result = None
+if 'code_text' not in st.session_state:
+    st.session_state.code_text = ''
+
+def db_connection_str():
+    user = st.session_state.user
+    pwd = st.session_state.pwd
+    host = st.session_state.hostname
+    prt = st.session_state.port
+    ns = st.session_state.namespace
+    return f"iris://{user}:{pwd}@{host}:{prt}/{ns}"
+
+def assistant_interaction(sqlzilla, prompt):
+    response = sqlzilla.prompt(prompt)
+    response = clean_response(response)
+    st.session_state.chat_history.append({"role": "assistant", "content": response})
+    return response
+
+def clean_response(response):
+    if response.startswith("```sql"):
+        response = response[6:-3]
+    return response
+
+left_co, cent_co, last_co = st.columns(3)
+with cent_co:
+    try:
+        st.image("small_logo.png", use_column_width=True)
+    except Exception as e:
+        st.error(f"Error loading image: {str(e)}")
+    st.write("SQLZilla")
+
+# Authentication configuration
+if not st.session_state.openai_api_key:
+    st.warning("Please provide your API key to proceed.")
+    st.session_state.openai_api_key = st.text_input("API Key", type="password")
+else:
+    if st.button("Config"):
+        host = st.text_input("Hostname", value=st.session_state.hostname)
+        username = st.text_input("Username", value=st.session_state.user)
+        password = st.text_input("Password", value=st.session_state.pwd, type="password")
+        namespace = st.text_input("Namespace", value=st.session_state.namespace)
+        port = st.text_input("Port", value=st.session_state.port)
+        api_key = st.text_input("API Key", value=st.session_state.openai_api_key, type="password")
+        if st.button("Save"):
+            st.session_state.hostname = host
+            st.session_state.port = port
+            st.session_state.pwd = password
+            st.session_state.user = username
+            st.session_state.namespace = namespace
+            st.session_state.openai_api_key = api_key
+            st.success("Configuration updated!")
+
+# from sqlalchemy import create_engine
+
+# def log(msg):
+#     import os
+#     os.write(1, f"{msg}\n".encode())
+
+# engine = create_engine(db_connection_str(), pool_size=1, max_overflow=0)
+# cnx = engine.connect().connection
+# log("Connection established")
+
+database_schema = None
+if (st.session_state.namespace and st.session_state.openai_api_key):
+
+    sqlzilla = SQLZilla(db_connection_str(), st.session_state.openai_api_key)
+    # Initial prompts for namespace and database schema
+    try:
+        query = """
+        SELECT SCHEMA_NAME
+        FROM INFORMATION_SCHEMA.SCHEMATA
+        """
+        rows = sqlzilla.execute_query(query)
+        options = [row[0] for row in rows or []]
+        database_schema = st.selectbox(
+            'Enter Database Schema',
+            options,
+            index=None,
+            placeholder="Select database schema...",
+        )
+    except:
+        database_schema = st.text_input('Enter Database Schema')
+        st.warning('Could not retrieve database schemas. Please provide one manually.')
+
+if (st.session_state.namespace and database_schema and st.session_state.openai_api_key):
+    context = sqlzilla.schema_context_management(database_schema)
+
+    # Layout for the page
+    col1, col2 = st.columns(2)
+
+    with col1:
+        editor_btn = [ {
+            "name": "Execute", "feather": "Play",
+            "primary": True,"hasText": True, "alwaysOn": True,
+            "showWithIcon": True,"commands": ["submit"],
+            "style": {
+                "bottom": "0.44rem",
+                "right": "0.4rem"
+            }
+        },]
+        editor_dict = code_editor(st.session_state.code_text, lang="sql", height=[10, 100], shortcuts="vscode", options={"placeholder":"Add your SQL here to test...", "showLineNumbers":True}, buttons=editor_btn)
+
+        if editor_dict['type'] == "submit":
+            st.session_state.code_text = editor_dict['text']
+            data = sqlzilla.execute_query(st.session_state.code_text)
+            # Display query result as dataframe
+            if (data is not None):
+                st.session_state.query_result = pd.DataFrame(data)
+                st.dataframe(st.session_state.query_result)
+
+    with col2:
+        # Display chat history
+        for message in st.session_state.chat_history:
+            st.chat_message(message["role"]).markdown(message["content"])
+
+        # React to user input
+        if prompt := st.chat_input("How can I assist you?"):
+            # Display user message in chat message container
+            st.chat_message("user").markdown(prompt)
+            # Add user message to chat history
+            st.session_state.chat_history.append({"role": "user", "content": prompt})
+
+            response = assistant_interaction(sqlzilla, prompt)
+
+            # Check if the response contains SQL code and update the editor
+            if "SELECT" in response.upper():
+                st.session_state.query = response
+                st.session_state.code_text = response
+                editor_dict['text'] = response
+                data = sqlzilla.execute_query(st.session_state.code_text)
+                st.session_state.query_result = pd.DataFrame(data)
+                st.rerun()
+            # Display assistant response in chat message container
+            with st.chat_message("assistant"):
+                st.markdown(response)
+            # Add assistant response to chat history
+            st.session_state.chat_history.append({"role": "assistant", "content": response})
+else:
+    st.warning('Please select a database schema to proceed.')
diff --git a/python/sqlzilla/app copy.py b/python/sqlzilla/app copy.py
new file mode 100644
index 0000000..cb15742
--- /dev/null
+++ b/python/sqlzilla/app copy.py
@@ -0,0 +1,172 @@
+import streamlit as st
+from code_editor import code_editor
+import requests
+from sqlalchemy import create_engine
+import pandas as pd
+from dotenv import load_dotenv
+import os
+from sqlalchemy import create_engine
+from sqlzilla import SQLZilla
+
+# Load environment variables from .env file
+load_dotenv()
+
+# Initialize session state
+if '_cnx' not in st.session_state:
+    st.session_state._cnx = None
+if 'hostname' not in st.session_state:
+    st.session_state.hostname = 'sqlzilla-iris-1'
+if 'user' not
in st.session_state: + st.session_state.user = '_system' +if 'pwd' not in st.session_state: + st.session_state.pwd = 'SYS' +if 'port' not in st.session_state: + st.session_state.port = '1972' +if 'namespace' not in st.session_state: + st.session_state.namespace = "IRISAPP" +if 'openai_api_key' not in st.session_state: + st.session_state.openai_api_key = os.getenv('OPENAI_API_KEY', '') +if 'chat_history' not in st.session_state: + st.session_state.chat_history = [{"role": "assistant", "content": "I'm SQLZilla, your friendly AI SQL helper.\n Ask me anything, from basic queries to complex optimizations."}] +if 'query_result' not in st.session_state: + st.session_state.query_result = None +if 'code_text' not in st.session_state: + st.session_state.code_text = '' + +def db_connection_str(): + user = st.session_state.user + pwd = st.session_state.pwd + host = st.session_state.hostname + prt = st.session_state.port + ns = st.session_state.namespace + return f"iris://{user}:{pwd}@{host}:{prt}/{ns}" + +def assistant_interaction(sqlzilla, prompt): + response = sqlzilla.prompt(prompt) + response = clean_response(response) + st.session_state.chat_history.append({"role": "assistant", "content": response}) + return response + +def clean_response(response): + if response.startswith("```sql"): + response = response[6:-3] + return response + +left_co, cent_co, last_co = st.columns(3) +with cent_co: + try: + st.image("small_logo.png", use_column_width=True) + except Exception as e: + st.error(f"Erro ao carregar a imagem: {str(e)}") + st.write("SQLZilla") + +# Authentication configuration +if not st.session_state.openai_api_key: + st.warning("Please provide your API key to proceed.") + st.session_state.openai_api_key = st.text_input("API Key", type="password") +else: + if st.button("Config"): + host = st.text_input("Hostname", value=st.session_state.hostname) + username = st.text_input("Username", value=st.session_state.user) + password = st.text_input("Password", value=st.session_state.pwd, type="password") + namespace = st.text_input("Namespace", value=st.session_state.namespace) + port = st.text_input("Port", value=st.session_state.port) + api_key = st.text_input("API Key", value=st.session_state.openai_api_key, type="password") + if st.button("Save"): + st.session_state.hostname = host + st.session_state.port = port + st.session_state.pwd = password + st.session_state.user = username + st.session_state.namespace = namespace + st.session_state.openai_api_key = api_key + st.success("Configuration updated!") + +database_schema = None +if (st.session_state.namespace and st.session_state.openai_api_key): + sqlzilla = SQLZilla(db_connection_str(), st.session_state.openai_api_key, st.session_state) + # Initial prompts for namespace and database schema + try: + query = """ + SELECT SCHEMA_NAME + FROM INFORMATION_SCHEMA.SCHEMATA + """ + rows = sqlzilla.execute_query(query) + options = [row[0] for row in rows or []] + database_schema = st.selectbox( + 'Enter Database Schema', + options, + index=None, + placeholder="Select database schema...", + ) + except: + database_schema = st.text_input('Enter Database Schema') + st.warning('Was not possible to retrieve database schemas. 
Please provide it manually.') + +if (st.session_state.namespace and database_schema and st.session_state.openai_api_key): + context = sqlzilla.schema_context_management(database_schema) + + # Layout for the page + col1, col2 = st.columns(2) + + with col1: + editor_btn = [ { + # "name": "Save on library", "feather": "Save", + # "primary": False,"hasText": True, "alwaysOn": True, + # "showWithIcon": True,"commands": ["save_lib"], + # "style": { + # "bottom": "0.44rem", + # "left": "0.4rem" + # } + # },{ + "name": "Execute", "feather": "Play", + "primary": True,"hasText": True, "alwaysOn": True, + "showWithIcon": True,"commands": ["submit"], + "style": { + "bottom": "0.44rem", + "right": "0.4rem" + } + },] + editor_dict = code_editor(st.session_state.code_text, lang="sql", height=[10, 100], shortcuts="vscode", options={"placeholder":"Add your SQL here to test...", "showLineNumbers":True}, buttons=editor_btn) + + if editor_dict['type'] == "submit": + st.session_state.code_text = editor_dict['text'] + data = sqlzilla.execute_query(st.session_state.code_text) + # Display query result as dataframe + if (data is not None): + st.session_state.query_result = pd.DataFrame(data) + st.dataframe(st.session_state.query_result) + elif editor_dict['type'] == "save_lib": + pass + + with col2: + # Display chat history + for message in st.session_state.chat_history: + st.chat_message(message["role"]).markdown(message["content"]) + + # React to user input + if prompt := st.chat_input("How can I assist you?"): + # Display user message in chat message container + st.chat_message("user").markdown(prompt) + # Add user message to chat history + st.session_state.chat_history.append({"role": "user", "content": prompt}) + + response = assistant_interaction(sqlzilla, prompt) + + # Check if the response contains SQL code and update the editor + if "SELECT" in response.upper(): + st.session_state.query = response + st.session_state.code_text = response + editor_dict['text'] = response + try: + data = sqlzilla.execute_query(st.session_state.code_text) + st.session_state.query_result = pd.DataFrame(data) + except Exception as e: + st.error(f"Error: {str(e)}") + st.rerun() + # Display assistant response in chat message container + with st.chat_message("assistant"): + st.markdown(response) + # Add assistant response to chat history + st.session_state.chat_history.append({"role": "assistant", "content": response}) +else: + st.warning('Please select a database schema to proceed.') diff --git a/python/sqlzilla/app-main.py b/python/sqlzilla/app-main.py new file mode 100644 index 0000000..ef52f9d --- /dev/null +++ b/python/sqlzilla/app-main.py @@ -0,0 +1,157 @@ +import streamlit as st +from code_editor import code_editor +import requests +from sqlalchemy import create_engine +import pandas as pd +from dotenv import load_dotenv +import os +from sqlalchemy import create_engine +from sqlzilla import SQLZilla + +# Load environment variables from .env file +load_dotenv() + +# Initialize session state +if 'hostname' not in st.session_state: + st.session_state.hostname = 'sqlzilla-iris-1' +if 'user' not in st.session_state: + st.session_state.user = '_system' +if 'pwd' not in st.session_state: + st.session_state.pwd = 'SYS' +if 'port' not in st.session_state: + st.session_state.port = '1972' +if 'namespace' not in st.session_state: + st.session_state.namespace = "IRISAPP" +if 'openai_api_key' not in st.session_state: + st.session_state.openai_api_key = os.getenv('OPENAI_API_KEY', '') +if 'chat_history' not in st.session_state: + 
st.session_state.chat_history = [{"role": "assistant", "content": "I'm SQLZilla, your friendly AI SQL helper.\n Ask me anything, from basic queries to complex optimizations."}] +if 'query_result' not in st.session_state: + st.session_state.query_result = None +if 'code_text' not in st.session_state: + st.session_state.code_text = '' + +def db_connection_str(): + user = st.session_state.user + pwd = st.session_state.pwd + host = st.session_state.hostname + prt = st.session_state.port + ns = st.session_state.namespace + return f"iris://{user}:{pwd}@{host}:{prt}/{ns}" + +def assistant_interaction(sqlzilla, prompt): + response = sqlzilla.prompt(prompt) + response = clean_response(response) + st.session_state.chat_history.append({"role": "assistant", "content": response}) + return response + +def clean_response(response): + if response.startswith("```sql"): + response = response[6:-3] + return response + +left_co, cent_co, last_co = st.columns(3) +with cent_co: + try: + st.image("small_logo.png", use_column_width=True) + except Exception as e: + st.error(f"Erro ao carregar a imagem: {str(e)}") + st.write("SQLZilla") + +# Authentication configuration +if not st.session_state.openai_api_key: + st.warning("Please provide your API key to proceed.") + st.session_state.openai_api_key = st.text_input("API Key", type="password") +else: + if st.button("Config"): + host = st.text_input("Hostname", value=st.session_state.hostname) + username = st.text_input("Username", value=st.session_state.user) + password = st.text_input("Password", value=st.session_state.pwd, type="password") + namespace = st.text_input("Namespace", value=st.session_state.namespace) + port = st.text_input("Port", value=st.session_state.port) + api_key = st.text_input("API Key", value=st.session_state.openai_api_key, type="password") + if st.button("Save"): + st.session_state.hostname = host + st.session_state.port = port + st.session_state.pwd = password + st.session_state.user = username + st.session_state.namespace = namespace + st.session_state.openai_api_key = api_key + st.success("Configuration updated!") + +database_schema = None +if (st.session_state.namespace and st.session_state.openai_api_key): + sqlzilla = SQLZilla(db_connection_str(), st.session_state.openai_api_key) + # Initial prompts for namespace and database schema + try: + query = """ + SELECT SCHEMA_NAME + FROM INFORMATION_SCHEMA.SCHEMATA + """ + rows = sqlzilla.execute_query(query) + options = [row[0] for row in rows or []] + database_schema = st.selectbox( + 'Enter Database Schema', + options, + index=None, + placeholder="Select database schema...", + ) + except: + database_schema = st.text_input('Enter Database Schema') + st.warning('Was not possible to retrieve database schemas. 
Please provide it manually.') + +if (st.session_state.namespace and database_schema and st.session_state.openai_api_key): + context = sqlzilla.schema_context_management(database_schema) + + # Layout for the page + col1, col2 = st.columns(2) + + with col1: + editor_btn = [ { + "name": "Execute", "feather": "Play", + "primary": True,"hasText": True, "alwaysOn": True, + "showWithIcon": True,"commands": ["submit"], + "style": { + "bottom": "0.44rem", + "right": "0.4rem" + } + },] + editor_dict = code_editor(st.session_state.code_text, lang="sql", height=[10, 100], shortcuts="vscode", options={"placeholder":"Add your SQL here to test...", "showLineNumbers":True}, buttons=editor_btn) + + if editor_dict['type'] == "submit": + st.session_state.code_text = editor_dict['text'] + data = sqlzilla.execute_query(st.session_state.code_text) + # Display query result as dataframe + if (data is not None): + st.session_state.query_result = pd.DataFrame(data) + st.dataframe(st.session_state.query_result) + + with col2: + # Display chat history + for message in st.session_state.chat_history: + st.chat_message(message["role"]).markdown(message["content"]) + + # React to user input + if prompt := st.chat_input("How can I assist you?"): + # Display user message in chat message container + st.chat_message("user").markdown(prompt) + # Add user message to chat history + st.session_state.chat_history.append({"role": "user", "content": prompt}) + + response = assistant_interaction(sqlzilla, prompt) + + # Check if the response contains SQL code and update the editor + if "SELECT" in response.upper(): + st.session_state.query = response + st.session_state.code_text = response + editor_dict['text'] = response + data = sqlzilla.execute_query(st.session_state.code_text) + st.session_state.query_result = pd.DataFrame(data) + st.rerun() + # Display assistant response in chat message container + with st.chat_message("assistant"): + st.markdown(response) + # Add assistant response to chat history + st.session_state.chat_history.append({"role": "assistant", "content": response}) +else: + st.warning('Please select a database schema to proceed.') diff --git a/python/sqlzilla/sqlzilla copy 2.py b/python/sqlzilla/sqlzilla copy 2.py new file mode 100644 index 0000000..a14352b --- /dev/null +++ b/python/sqlzilla/sqlzilla copy 2.py @@ -0,0 +1,324 @@ +from sqlalchemy import create_engine +import hashlib +import pandas as pd; + +from langchain_core.prompts import PromptTemplate, ChatPromptTemplate +from langchain_core.example_selectors import SemanticSimilarityExampleSelector +from langchain_openai import OpenAIEmbeddings, ChatOpenAI +from langchain.docstore.document import Document +from langchain_community.document_loaders import DataFrameLoader +from langchain.text_splitter import CharacterTextSplitter +from langchain_core.output_parsers import StrOutputParser +from langchain_iris import IRISVector + +class SQLZilla: + def __init__(self, connection_string, openai_api_key, cnx=None): + self.log('criou') + self.openai_api_key = openai_api_key + self.iris_conn_str = connection_string + self.engine = None + self.cnx = cnx + self.context = {} + self.context["top_k"] = 3 + self.tables_vector_store = None + self.example_selector = None + self.chain_model = None + self.example_prompt = None + self.create_chain_model() + + def __del__(self): + # self.get_connection().close() + # self.engine.dispose() + # self.cnx = None + self.log('deletou') + + def log(self, msg): + import os + os.write(1, f"{msg}\n".encode()) + + def get_connection(self): + 
# if self.cnx is None: + # self.engine = create_engine(self.iris_conn_str) + # self.cnx = self.engine.connect().connection + if self.engine is None: + self.engine = create_engine(self.iris_conn_str) + if not self.cnx is None: + self.cnx.close() + self.log("connection closed") + self.cnx = self.engine.connect().connection + self.log("connection opened") + return self.cnx + + def get_examples(self): + return [ + { + "input": "List all aircrafts.", + "query": "SELECT * FROM Aviation.Aircraft" + }, + { + "input": "Find all incidents for the aircraft with ID 'N12345'.", + "query": "SELECT * FROM Aviation.Event WHERE EventId IN (SELECT EventId FROM Aviation.Aircraft WHERE ID = 'N12345')" + }, + { + "input": "List all incidents in the 'Commercial' operation type.", + "query": "SELECT * FROM Aviation.Event WHERE EventId IN (SELECT EventId FROM Aviation.Aircraft WHERE OperationType = 'Commercial')" + }, + { + "input": "Find the total number of incidents.", + "query": "SELECT COUNT(*) FROM Aviation.Event" + }, + { + "input": "List all incidents that occurred in 'Canada'.", + "query": "SELECT * FROM Aviation.Event WHERE LocationCountry = 'Canada'" + }, + { + "input": "How many incidents are associated with the aircraft with AircraftKey 5?", + "query": "SELECT COUNT(*) FROM Aviation.Aircraft WHERE AircraftKey = 5" + }, + { + "input": "Find the total number of distinct aircrafts involved in incidents.", + "query": "SELECT COUNT(DISTINCT AircraftKey) FROM Aviation.Aircraft" + }, + { + "input": "List all incidents that occurred after 5 PM.", + "query": "SELECT * FROM Aviation.Event WHERE EventTime > 1700" + }, + { + "input": "Who are the top 5 operators by the number of incidents?", + "query": "SELECT TOP 5 OperatorName, COUNT(*) AS IncidentCount FROM Aviation.Aircraft GROUP BY OperatorName ORDER BY IncidentCount DESC" + }, + { + "input": "Which incidents occurred in the year 2020?", + "query": "SELECT * FROM Aviation.Event WHERE YEAR(EventDate) = '2020'" + }, + { + "input": "What was the month with most events in the year 2020?", + "query": "SELECT TOP 1 MONTH(EventDate) EventMonth, COUNT(*) EventCount FROM Aviation.Event WHERE YEAR(EventDate) = '2020' GROUP BY MONTH(EventDate) ORDER BY EventCount DESC" + }, + { + "input": "How many crew members were involved in incidents?", + "query": "SELECT COUNT(*) FROM Aviation.Crew" + }, + { + "input": "List all incidents with detailed aircraft information for incidents that occurred in the year 2012.", + "query": "SELECT e.EventId, e.EventDate, a.AircraftManufacturer, a.AircraftModel, a.AircraftCategory FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE Year(e.EventDate) = 2012" + }, + { + "input": "Find all incidents where there were more than 5 injuries and include the aircraft manufacturer and model.", + "query": "SELECT e.EventId, e.InjuriesTotal, a.AircraftManufacturer, a.AircraftModel FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE e.InjuriesTotal > 5" + }, + { + "input": "List all crew members involved in incidents with serious injuries, along with the incident date and location.", + "query": "SELECT c.CrewNumber, c.Age, c.Sex, e.EventDate, e.LocationCity, e.LocationState FROM Aviation.Crew c JOIN Aviation.Event e ON c.EventId = e.EventId WHERE c.Injury = 'Serious'" + }, + ] + + def get_table_definitions_array(self, schema, table=None): + cursor = self.get_connection().cursor() + + # Base query to get columns information + query = """ + SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE, 
IS_NULLABLE, COLUMN_DEFAULT, PRIMARY_KEY, null EXTRA + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_SCHEMA = %s + """ + + # Parameters for the query + params = [schema] + + # Adding optional filters + if table: + query += " AND TABLE_NAME = %s" + params.append(table) + + # Execute the query + cursor.execute(query, params) + + # Fetch the results + rows = cursor.fetchall() + + # Process the results to generate the table definition(s) + table_definitions = {} + for row in rows: + table_schema, table_name, column_name, column_type, is_nullable, column_default, column_key, extra = row + if table_name not in table_definitions: + table_definitions[table_name] = [] + table_definitions[table_name].append({ + "column_name": column_name, + "column_type": column_type, + "is_nullable": is_nullable, + "column_default": column_default, + "column_key": column_key, + "extra": extra + }) + + primary_keys = {} + + # Build the output string + result = [] + for table_name, columns in table_definitions.items(): + table_def = f"CREATE TABLE {schema}.{table_name} (\n" + column_definitions = [] + for column in columns: + column_def = f" {column['column_name']} {column['column_type']}" + if column['is_nullable'] == "NO": + column_def += " NOT NULL" + if column['column_default'] is not None: + column_def += f" DEFAULT {column['column_default']}" + if column['extra']: + column_def += f" {column['extra']}" + column_definitions.append(column_def) + if table_name in primary_keys: + pk_def = f" PRIMARY KEY ({', '.join(primary_keys[table_name])})" + column_definitions.append(pk_def) + table_def += ",\n".join(column_definitions) + table_def += "\n);" + result.append(table_def) + + return result + + def get_table_definitions(self, schema, table=None): + return "\n\n".join(self.get_table_definitions_array(schema=schema, table=table)) + + def get_ids_from_string_array(self, array): + return [str(hashlib.md5(x.encode()).hexdigest()) for x in array] + + def exists_in_db(self, collection_name, id): + schema_name = "SQLUser" + + cursor = self.get_connection().cursor() + query = f""" + SELECT TOP 1 id + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_SCHEMA = %s and TABLE_NAME = %s + """ + params = [schema_name, collection_name] + cursor.execute(query, params) + rows = cursor.fetchall() + if len(rows) == 0: + return False + + del cursor, query, params, rows + + cursor = self.get_connection().cursor() + query = f""" + SELECT TOP 1 id + FROM {collection_name} + WHERE id = %s + """ + params = [id] + cursor.execute(query, params) + rows = cursor.fetchall() + return len(rows) > 0 + + def filter_not_in_collection(self, collection_name, docs_array, ids_array): + filtered = [x for x in zip(docs_array, ids_array) if not self.exists_in_db(collection_name, x[1])] + return list(zip(*filtered)) or ([], []) + + def schema_context_management(self, schema): + table_def = self.get_table_definitions_array(schema) + self.table_df = pd.DataFrame(data=table_def, columns=["col_def"]) + self.table_df["id"] = self.table_df.index + 1 + loader = DataFrameLoader(self.table_df, page_content_column="col_def") + documents = loader.load() + text_splitter = CharacterTextSplitter(chunk_size=400, chunk_overlap=20, separator="\n") + self.tables_docs = text_splitter.split_documents(documents) + new_tables_docs, tables_docs_ids = self.filter_not_in_collection( + "sql_tables", + self.tables_docs, + self.get_ids_from_string_array([x.page_content for x in self.tables_docs]) + ) + self.tables_docs_ids = tables_docs_ids + self.tables_vector_store = 
IRISVector.from_documents( + embedding = OpenAIEmbeddings(openai_api_key=self.openai_api_key), + documents = self.tables_docs, + connection_string= self.iris_conn_str, + collection_name="sql_tables", + ids=self.tables_docs_ids + ) + + examples = self.get_examples() + new_sql_samples, sql_samples_ids = self.filter_not_in_collection( + "sql_samples", + examples, + self.get_ids_from_string_array([x['input'] for x in examples]) + ) + self.example_selector = SemanticSimilarityExampleSelector.from_examples( + new_sql_samples, + OpenAIEmbeddings(openai_api_key=self.openai_api_key), + IRISVector, + k=5, + input_keys=["input"], + connection_string=self.iris_conn_str, + collection_name="sql_samples", + ids=sql_samples_ids + ) + + def create_chain_model(self): + if not self.chain_model is None: + return self.chain_model + + iris_sql_template = """ +You are an InterSystems IRIS expert. Given an input question, first create a syntactically correct InterSystems IRIS query to run and return the answer to the input question. +Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the TOP clause as per InterSystems IRIS. You can order the results to return the most informative data in the database. +Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in single quotes ('') to denote them as delimited identifiers. +Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. +Pay attention to use CAST(CURRENT_DATE as date) function to get the current date, if the question involves "today". +Use double quotes to delimit columns identifiers. +Return just plain SQL; don't apply any kind of formatting. + """ + tables_prompt_template = """ + Only use the following tables: + {table_info} + """ + prompt_sql_few_shots_template = """ + Below are a number of examples of questions and their corresponding SQL queries. 
+ + {examples_value} + """ + example_prompt_template = "User input: {input}\nSQL query: {query}" + example_prompt = PromptTemplate.from_template(example_prompt_template) + self.example_prompt = example_prompt + + user_prompt = "\n"+example_prompt.invoke({"input": "{input}", "query": ""}).to_string() + prompt = ( + ChatPromptTemplate.from_messages([("system", iris_sql_template)]) + + ChatPromptTemplate.from_messages([("system", tables_prompt_template)]) + + ChatPromptTemplate.from_messages([("system", prompt_sql_few_shots_template)]) + + ChatPromptTemplate.from_messages([("human", user_prompt)]) + ) + + model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0, api_key=self.openai_api_key) + output_parser = StrOutputParser() + self.chain_model = prompt | model | output_parser + + def prompt(self, input): + self.context["input"] = input + + relevant_tables_docs = self.tables_vector_store.similarity_search(input) + relevant_tables_docs_indices = [x.metadata["id"] for x in relevant_tables_docs] + indices = self.table_df["id"].isin(relevant_tables_docs_indices) + relevant_tables_array = [x for x in self.table_df[indices]["col_def"]] + self.context["table_info"] = "\n\n".join(relevant_tables_array) + + self.context["examples_value"] = "\n\n".join([ + self.example_prompt.invoke(x).to_string() for x in self.example_selector.select_examples({"input": self.context["input"]}) + ]) + + self.log(self.context) + + response = self.create_chain_model().invoke({ + "top_k": self.context["top_k"], + "table_info": self.context["table_info"], + "examples_value": self.get_examples(), + "input": input + }) + return response + + def execute_query(self, query): + cursor = self.get_connection().cursor() + # Execute the query + cursor.execute(query) + + # Fetch the results + return cursor.fetchall() diff --git a/python/sqlzilla/sqlzilla copy.py b/python/sqlzilla/sqlzilla copy.py new file mode 100644 index 0000000..d1044e0 --- /dev/null +++ b/python/sqlzilla/sqlzilla copy.py @@ -0,0 +1,262 @@ +from sqlalchemy import create_engine +import hashlib +import pandas as pd; + +from langchain_core.prompts import PromptTemplate, ChatPromptTemplate +from langchain_core.example_selectors import SemanticSimilarityExampleSelector +from langchain_openai import OpenAIEmbeddings, ChatOpenAI +from langchain.docstore.document import Document +from langchain_community.document_loaders import DataFrameLoader +from langchain.text_splitter import CharacterTextSplitter +from langchain_core.output_parsers import StrOutputParser +from langchain_iris import IRISVector + +class SQLZilla: + _cnx = None + + def __init__(self, connection_string, openai_api_key, state): + self.iris_conn_str = connection_string + self.openai_api_key = openai_api_key + + if state._cnx is None: + self.log("criou") + self.engine = create_engine(connection_string) + state._cnx = self.engine.connect().connection + self.cnx = state._cnx + + self.context = {} + self.context["top_k"] = 3 + self.tables_vector_store = None + self.example_selector = None + self.chain_model = None + self.example_prompt = None + self.create_chain_model() + + def log(self, msg): + import os + os.write(1, f"{msg}\n".encode()) + + def create_examples_table(self): + sql = """ + CREATE TABLE IF NOT EXISTS sqlzilla.examples ( + id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, + prompt VARCHAR(255) NOT NULL, + query VARCHAR(255) NOT NULL + ); + """ + self.execute_query(sql) + + def get_examples(self): + sql = "SELECT prompt, query FROM sqlzilla.examples" + rows = self.execute_query(sql) + examples = [{ + 
"input": row[0], + "query": row[1] + } for row in rows] + return examples + + def get_table_definitions_array(self, schema, table=None): + cursor = self.cnx.cursor() + + # Base query to get columns information + query = """ + SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT, PRIMARY_KEY, null EXTRA + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_SCHEMA = %s + """ + + # Parameters for the query + params = [schema] + + # Adding optional filters + if table: + query += " AND TABLE_NAME = %s" + params.append(table) + + # Execute the query + cursor.execute(query, params) + + # Fetch the results + rows = cursor.fetchall() + + # Process the results to generate the table definition(s) + table_definitions = {} + for row in rows: + table_schema, table_name, column_name, column_type, is_nullable, column_default, column_key, extra = row + if table_name not in table_definitions: + table_definitions[table_name] = [] + table_definitions[table_name].append({ + "column_name": column_name, + "column_type": column_type, + "is_nullable": is_nullable, + "column_default": column_default, + "column_key": column_key, + "extra": extra + }) + + primary_keys = {} + + # Build the output string + result = [] + for table_name, columns in table_definitions.items(): + table_def = f"CREATE TABLE {schema}.{table_name} (\n" + column_definitions = [] + for column in columns: + column_def = f" {column['column_name']} {column['column_type']}" + if column['is_nullable'] == "NO": + column_def += " NOT NULL" + if column['column_default'] is not None: + column_def += f" DEFAULT {column['column_default']}" + if column['extra']: + column_def += f" {column['extra']}" + column_definitions.append(column_def) + if table_name in primary_keys: + pk_def = f" PRIMARY KEY ({', '.join(primary_keys[table_name])})" + column_definitions.append(pk_def) + table_def += ",\n".join(column_definitions) + table_def += "\n);" + result.append(table_def) + + return result + + def get_table_definitions(self, schema, table=None): + return "\n\n".join(self.get_table_definitions_array(schema=schema, table=table)) + + def get_ids_from_string_array(self, array): + return [str(hashlib.md5(x.encode()).hexdigest()) for x in array] + + def exists_in_db(self, collection_name, id): + schema_name = "SQLUser" + + cursor = self.cnx.cursor() + query = f""" + SELECT TOP 1 id + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_SCHEMA = %s and TABLE_NAME = %s + """ + params = [schema_name, collection_name] + cursor.execute(query, params) + rows = cursor.fetchall() + if len(rows) == 0: + return False + + del cursor, query, params, rows + + cursor = self.cnx.cursor() + query = f""" + SELECT TOP 1 id + FROM {collection_name} + WHERE id = %s + """ + params = [id] + cursor.execute(query, params) + rows = cursor.fetchall() + return len(rows) > 0 + + def filter_not_in_collection(self, collection_name, docs_array, ids_array): + filtered = [x for x in zip(docs_array, ids_array) if not self.exists_in_db(collection_name, x[1])] + return list(zip(*filtered)) or ([], []) + + def schema_context_management(self, schema): + table_def = self.get_table_definitions_array(schema) + self.table_df = pd.DataFrame(data=table_def, columns=["col_def"]) + self.table_df["id"] = self.table_df.index + 1 + loader = DataFrameLoader(self.table_df, page_content_column="col_def") + documents = loader.load() + text_splitter = CharacterTextSplitter(chunk_size=400, chunk_overlap=20, separator="\n") + self.tables_docs = text_splitter.split_documents(documents) + new_tables_docs, 
tables_docs_ids = self.filter_not_in_collection( + "sql_tables", + self.tables_docs, + self.get_ids_from_string_array([x.page_content for x in self.tables_docs]) + ) + self.tables_docs_ids = tables_docs_ids + self.tables_vector_store = IRISVector.from_documents( + embedding = OpenAIEmbeddings(openai_api_key=self.openai_api_key), + documents = self.tables_docs, + connection_string= self.iris_conn_str, + collection_name="sql_tables", + ids=self.tables_docs_ids + ) + + examples = self.get_examples() + new_sql_samples, sql_samples_ids = self.filter_not_in_collection( + "sql_samples", + examples, + self.get_ids_from_string_array([x['input'] for x in examples]) + ) + self.example_selector = SemanticSimilarityExampleSelector.from_examples( + new_sql_samples, + OpenAIEmbeddings(openai_api_key=self.openai_api_key), + IRISVector, + k=5, + input_keys=["input"], + connection_string=self.iris_conn_str, + collection_name="sql_samples", + ids=sql_samples_ids + ) + + def create_chain_model(self): + iris_sql_template = """ +You are an InterSystems IRIS expert. Given an input question, first create a syntactically correct InterSystems IRIS query to run and return the answer to the input question. +Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the TOP clause as per InterSystems IRIS. You can order the results to return the most informative data in the database. +Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in single quotes ('') to denote them as delimited identifiers. +Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. +Pay attention to use CAST(CURRENT_DATE as date) function to get the current date, if the question involves "today". +Use double quotes to delimit columns identifiers. +Return just plain SQL; don't apply any kind of formatting. + """ + tables_prompt_template = """ + Only use the following tables: + {table_info} + """ + prompt_sql_few_shots_template = """ + Below are a number of examples of questions and their corresponding SQL queries. 
+ + {examples_value} + """ + example_prompt_template = "User input: {input}\nSQL query: {query}" + example_prompt = PromptTemplate.from_template(example_prompt_template) + user_prompt = "\n"+example_prompt.invoke({"input": "{input}", "query": ""}).to_string() + prompt = ( + ChatPromptTemplate.from_messages([("system", iris_sql_template)]) + + ChatPromptTemplate.from_messages([("system", tables_prompt_template)]) + + ChatPromptTemplate.from_messages([("system", prompt_sql_few_shots_template)]) + + ChatPromptTemplate.from_messages([("human", user_prompt)]) + ) + + model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0, api_key=self.openai_api_key) + output_parser = StrOutputParser() + self.chain_model = prompt | model | output_parser + self.example_prompt = example_prompt + + def prompt(self, input): + self.context["input"] = input + + relevant_tables_docs = self.tables_vector_store.similarity_search(input) + relevant_tables_docs_indices = [x.metadata["id"] for x in relevant_tables_docs] + indices = self.table_df["id"].isin(relevant_tables_docs_indices) + relevant_tables_array = [x for x in self.table_df[indices]["col_def"]] + self.context["table_info"] = "\n\n".join(relevant_tables_array) + + self.context["examples_value"] = "\n\n".join([ + self.example_prompt.invoke(x).to_string() for x in self.example_selector.select_examples({"input": self.context["input"]}) + ]) + + self.log(self.context) + + response = self.chain_model.invoke({ + "top_k": self.context["top_k"], + "table_info": self.context["table_info"], + "examples_value": self.context["examples_value"], + "input": self.context["input"] + }) + return response + + def execute_query(self, query): + cursor = self.cnx.cursor() + # Execute the query + cursor.execute(query) + + # Fetch the results + return cursor.fetchall() diff --git a/python/sqlzilla/sqlzilla-main.py b/python/sqlzilla/sqlzilla-main.py new file mode 100644 index 0000000..9ee62c0 --- /dev/null +++ b/python/sqlzilla/sqlzilla-main.py @@ -0,0 +1,267 @@ +from sqlalchemy import create_engine +import hashlib +import pandas as pd; + +from langchain_core.prompts import PromptTemplate, ChatPromptTemplate +from langchain_core.example_selectors import SemanticSimilarityExampleSelector +from langchain_openai import OpenAIEmbeddings, ChatOpenAI +from langchain.docstore.document import Document +from langchain_community.document_loaders import DataFrameLoader +from langchain.text_splitter import CharacterTextSplitter +from langchain_core.output_parsers import StrOutputParser +from langchain_iris import IRISVector + +class SQLZilla: + def __init__(self, connection_string, openai_api_key): + self.openai_api_key = openai_api_key + self.iris_conn_str = connection_string + self.engine = create_engine(connection_string) + self.cnx = self.engine.connect().connection + self.context = {} + self.context["top_k"] = 3 + self.examples = [ + { + "input": "List all aircrafts.", + "query": "SELECT * FROM Aviation.Aircraft" + }, + { + "input": "Find all incidents for the aircraft with ID 'N12345'.", + "query": "SELECT * FROM Aviation.Event WHERE EventId IN (SELECT EventId FROM Aviation.Aircraft WHERE ID = 'N12345')" + }, + { + "input": "List all incidents in the 'Commercial' operation type.", + "query": "SELECT * FROM Aviation.Event WHERE EventId IN (SELECT EventId FROM Aviation.Aircraft WHERE OperationType = 'Commercial')" + }, + { + "input": "Find the total number of incidents.", + "query": "SELECT COUNT(*) FROM Aviation.Event" + }, + { + "input": "List all incidents that occurred in 'Canada'.", + 
"query": "SELECT * FROM Aviation.Event WHERE LocationCountry = 'Canada'" + }, + { + "input": "How many incidents are associated with the aircraft with AircraftKey 5?", + "query": "SELECT COUNT(*) FROM Aviation.Aircraft WHERE AircraftKey = 5" + }, + { + "input": "Find the total number of distinct aircrafts involved in incidents.", + "query": "SELECT COUNT(DISTINCT AircraftKey) FROM Aviation.Aircraft" + }, + { + "input": "List all incidents that occurred after 5 PM.", + "query": "SELECT * FROM Aviation.Event WHERE EventTime > 1700" + }, + { + "input": "Who are the top 5 operators by the number of incidents?", + "query": "SELECT TOP 5 OperatorName, COUNT(*) AS IncidentCount FROM Aviation.Aircraft GROUP BY OperatorName ORDER BY IncidentCount DESC" + }, + { + "input": "Which incidents occurred in the year 2020?", + "query": "SELECT * FROM Aviation.Event WHERE YEAR(EventDate) = '2020'" + }, + { + "input": "What was the month with most events in the year 2020?", + "query": "SELECT TOP 1 MONTH(EventDate) EventMonth, COUNT(*) EventCount FROM Aviation.Event WHERE YEAR(EventDate) = '2020' GROUP BY MONTH(EventDate) ORDER BY EventCount DESC" + }, + { + "input": "How many crew members were involved in incidents?", + "query": "SELECT COUNT(*) FROM Aviation.Crew" + }, + { + "input": "List all incidents with detailed aircraft information for incidents that occurred in the year 2012.", + "query": "SELECT e.EventId, e.EventDate, a.AircraftManufacturer, a.AircraftModel, a.AircraftCategory FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE Year(e.EventDate) = 2012" + }, + { + "input": "Find all incidents where there were more than 5 injuries and include the aircraft manufacturer and model.", + "query": "SELECT e.EventId, e.InjuriesTotal, a.AircraftManufacturer, a.AircraftModel FROM Aviation.Event e JOIN Aviation.Aircraft a ON e.EventId = a.EventId WHERE e.InjuriesTotal > 5" + }, + { + "input": "List all crew members involved in incidents with serious injuries, along with the incident date and location.", + "query": "SELECT c.CrewNumber, c.Age, c.Sex, e.EventDate, e.LocationCity, e.LocationState FROM Aviation.Crew c JOIN Aviation.Event e ON c.EventId = e.EventId WHERE c.Injury = 'Serious'" + }, + ] + + def get_table_definitions_array(self, schema, table=None): + cursor = self.cnx.cursor() + + # Base query to get columns information + query = """ + SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT, PRIMARY_KEY, null EXTRA + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_SCHEMA = %s + """ + + # Parameters for the query + params = [schema] + + # Adding optional filters + if table: + query += " AND TABLE_NAME = %s" + params.append(table) + + # Execute the query + cursor.execute(query, params) + + # Fetch the results + rows = cursor.fetchall() + + # Process the results to generate the table definition(s) + table_definitions = {} + for row in rows: + table_schema, table_name, column_name, column_type, is_nullable, column_default, column_key, extra = row + if table_name not in table_definitions: + table_definitions[table_name] = [] + table_definitions[table_name].append({ + "column_name": column_name, + "column_type": column_type, + "is_nullable": is_nullable, + "column_default": column_default, + "column_key": column_key, + "extra": extra + }) + + primary_keys = {} + + # Build the output string + result = [] + for table_name, columns in table_definitions.items(): + table_def = f"CREATE TABLE {schema}.{table_name} (\n" + column_definitions = [] + for column 
in columns: + column_def = f" {column['column_name']} {column['column_type']}" + if column['is_nullable'] == "NO": + column_def += " NOT NULL" + if column['column_default'] is not None: + column_def += f" DEFAULT {column['column_default']}" + if column['extra']: + column_def += f" {column['extra']}" + column_definitions.append(column_def) + if table_name in primary_keys: + pk_def = f" PRIMARY KEY ({', '.join(primary_keys[table_name])})" + column_definitions.append(pk_def) + table_def += ",\n".join(column_definitions) + table_def += "\n);" + result.append(table_def) + + return result + + def get_table_definitions(self, schema, table=None): + return "\n\n".join(self.get_table_definitions_array(schema=schema, table=table)) + + def get_ids_from_string_array(self, array): + return [str(hashlib.md5(x.encode()).hexdigest()) for x in array] + + def exists_in_db(self, collection_name, id): + schema_name = "SQLUser" + + cursor = self.cnx.cursor() + query = f""" + SELECT TOP 1 id + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_SCHEMA = %s and TABLE_NAME = %s + """ + params = [schema_name, collection_name] + cursor.execute(query, params) + rows = cursor.fetchall() + if len(rows) == 0: + return False + + del cursor, query, params, rows + + cursor = self.cnx.cursor() + query = f""" + SELECT TOP 1 id + FROM {collection_name} + WHERE id = %s + """ + params = [id] + cursor.execute(query, params) + rows = cursor.fetchall() + return len(rows) > 0 + + def filter_not_in_collection(self, collection_name, docs_array, ids_array): + filtered = [x for x in zip(docs_array, ids_array) if not self.exists_in_db(collection_name, x[1])] + return list(zip(*filtered)) or ([], []) + + def schema_context_management(self, schema): + table_def = self.get_table_definitions_array(schema) + self.table_df = pd.DataFrame(data=table_def, columns=["col_def"]) + self.table_df["id"] = self.table_df.index + 1 + loader = DataFrameLoader(self.table_df, page_content_column="col_def") + documents = loader.load() + text_splitter = CharacterTextSplitter(chunk_size=400, chunk_overlap=20, separator="\n") + self.tables_docs = text_splitter.split_documents(documents) + new_tables_docs, tables_docs_ids = self.filter_not_in_collection( + "sql_tables", + self.tables_docs, + self.get_ids_from_string_array([x.page_content for x in self.tables_docs]) + ) + self.tables_docs_ids = tables_docs_ids + + + def prompt(self, input): + self.context["input"] = input + db = IRISVector.from_documents( + embedding = OpenAIEmbeddings(openai_api_key=self.openai_api_key), + documents = self.tables_docs, + connection_string= self.iris_conn_str, + collection_name="sql_tables", + ids=self.tables_docs_ids + ) + relevant_tables_docs = db.similarity_search(input) + relevant_tables_docs_indices = [x.metadata["id"] for x in relevant_tables_docs] + indices = self.table_df["id"].isin(relevant_tables_docs_indices) + relevant_tables_array = [x for x in self.table_df[indices]["col_def"]] + self.context["table_info"] = "\n\n".join(relevant_tables_array) + new_sql_samples, sql_samples_ids = self.filter_not_in_collection( + "sql_samples", + self.examples, + self.get_ids_from_string_array([x['input'] for x in self.examples]) + ) + iris_sql_template = """ +You are an InterSystems IRIS expert. Given an input question, first create a syntactically correct InterSystems IRIS query to run and return the answer to the input question. 
+Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the TOP clause as per InterSystems IRIS. You can order the results to return the most informative data in the database. +Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in single quotes ('') to denote them as delimited identifiers. +Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. +Pay attention to use CAST(CURRENT_DATE as date) function to get the current date, if the question involves "today". +Use double quotes to delimit columns identifiers. +Return just plain SQL; don't apply any kind of formatting. + """ + tables_prompt_template = """ + Only use the following tables: + {table_info} + """ + prompt_sql_few_shots_template = """ + Below are a number of examples of questions and their corresponding SQL queries. + + {examples_value} + """ + example_prompt_template = "User input: {input}\nSQL query: {query}" + example_prompt = PromptTemplate.from_template(example_prompt_template) + user_prompt = "\n"+example_prompt.invoke({"input": "{input}", "query": ""}).to_string() + prompt = ( + ChatPromptTemplate.from_messages([("system", iris_sql_template)]) + + ChatPromptTemplate.from_messages([("system", tables_prompt_template)]) + + ChatPromptTemplate.from_messages([("system", prompt_sql_few_shots_template)]) + + ChatPromptTemplate.from_messages([("human", user_prompt)]) + ) + + model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0, api_key=self.openai_api_key) + output_parser = StrOutputParser() + chain_model = prompt | model | output_parser + response = chain_model.invoke({ + "top_k": self.context["top_k"], + "table_info": self.context["table_info"], + "examples_value": self.examples, + "input": input + }) + return response + + def execute_query(self, query): + cursor = self.cnx.cursor() + # Execute the query + cursor.execute(query) + + # Fetch the results + return cursor.fetchall() diff --git a/python/sqlzilla/sqlzilla.py b/python/sqlzilla/sqlzilla.py index 9ee62c0..eafb4c0 100644 --- a/python/sqlzilla/sqlzilla.py +++ b/python/sqlzilla/sqlzilla.py @@ -13,13 +13,23 @@ class SQLZilla: def __init__(self, connection_string, openai_api_key): + self.log('criou') self.openai_api_key = openai_api_key - self.iris_conn_str = connection_string + # self.iris_conn_str = connection_string self.engine = create_engine(connection_string) - self.cnx = self.engine.connect().connection + self.conn_wrapper = self.engine.connect() + self.connection = self.conn_wrapper.connection + self.log('connection opened') self.context = {} self.context["top_k"] = 3 - self.examples = [ + self.tables_vector_store = None + self.example_selector = None + self.chain_model = None + self.example_prompt = None + self.create_chain_model() + + def get_examples(self): + return [ { "input": "List all aircrafts.", "query": "SELECT * FROM Aviation.Aircraft" @@ -82,8 +92,20 @@ def __init__(self, connection_string, openai_api_key): }, ] + def __del__(self): + self.log('deletou') + if not self.connection is None: + self.log('connection closed') + self.connection.close() + if not self.engine is None: + self.engine.dispose() + + def log(self, msg): + import os + os.write(1, f"{msg}\n".encode()) + def get_table_definitions_array(self, schema, table=None): - cursor = self.cnx.cursor() + cursor = 
self.connection.cursor() # Base query to get columns information query = """ @@ -155,7 +177,7 @@ def get_ids_from_string_array(self, array): def exists_in_db(self, collection_name, id): schema_name = "SQLUser" - cursor = self.cnx.cursor() + cursor = self.connection.cursor() query = f""" SELECT TOP 1 id FROM INFORMATION_SCHEMA.TABLES @@ -169,7 +191,7 @@ def exists_in_db(self, collection_name, id): del cursor, query, params, rows - cursor = self.cnx.cursor() + cursor = self.connection.cursor() query = f""" SELECT TOP 1 id FROM {collection_name} @@ -198,27 +220,39 @@ def schema_context_management(self, schema): self.get_ids_from_string_array([x.page_content for x in self.tables_docs]) ) self.tables_docs_ids = tables_docs_ids + if self.tables_vector_store is None: + self.tables_vector_store = IRISVector.from_documents( + embedding = OpenAIEmbeddings(openai_api_key=self.openai_api_key), + documents = self.tables_docs, + # connection_string= self.iris_conn_str, + connection=self.conn_wrapper, + collection_name="sql_tables", + ids=self.tables_docs_ids + ) + + if self.example_selector is None: + examples = self.get_examples() + new_sql_samples, sql_samples_ids = self.filter_not_in_collection( + "sql_samples", + examples, + self.get_ids_from_string_array([x['input'] for x in examples]) + ) + self.example_selector = SemanticSimilarityExampleSelector.from_examples( + new_sql_samples, + OpenAIEmbeddings(openai_api_key=self.openai_api_key), + IRISVector, + k=5, + input_keys=["input"], + # connection_string=self.iris_conn_str, + connection=self.conn_wrapper, + collection_name="sql_samples", + ids=sql_samples_ids + ) - - def prompt(self, input): - self.context["input"] = input - db = IRISVector.from_documents( - embedding = OpenAIEmbeddings(openai_api_key=self.openai_api_key), - documents = self.tables_docs, - connection_string= self.iris_conn_str, - collection_name="sql_tables", - ids=self.tables_docs_ids - ) - relevant_tables_docs = db.similarity_search(input) - relevant_tables_docs_indices = [x.metadata["id"] for x in relevant_tables_docs] - indices = self.table_df["id"].isin(relevant_tables_docs_indices) - relevant_tables_array = [x for x in self.table_df[indices]["col_def"]] - self.context["table_info"] = "\n\n".join(relevant_tables_array) - new_sql_samples, sql_samples_ids = self.filter_not_in_collection( - "sql_samples", - self.examples, - self.get_ids_from_string_array([x['input'] for x in self.examples]) - ) + def create_chain_model(self): + if not self.chain_model is None: + return self.chain_model + iris_sql_template = """ You are an InterSystems IRIS expert. Given an input question, first create a syntactically correct InterSystems IRIS query to run and return the answer to the input question. Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the TOP clause as per InterSystems IRIS. You can order the results to return the most informative data in the database. 
@@ -239,6 +273,8 @@ def prompt(self, input): """ example_prompt_template = "User input: {input}\nSQL query: {query}" example_prompt = PromptTemplate.from_template(example_prompt_template) + self.example_prompt = example_prompt + user_prompt = "\n"+example_prompt.invoke({"input": "{input}", "query": ""}).to_string() prompt = ( ChatPromptTemplate.from_messages([("system", iris_sql_template)]) @@ -249,17 +285,33 @@ def prompt(self, input): model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0, api_key=self.openai_api_key) output_parser = StrOutputParser() - chain_model = prompt | model | output_parser - response = chain_model.invoke({ + self.chain_model = prompt | model | output_parser + + def prompt(self, input): + self.context["input"] = input + + relevant_tables_docs = self.tables_vector_store.similarity_search(input) + relevant_tables_docs_indices = [x.metadata["id"] for x in relevant_tables_docs] + indices = self.table_df["id"].isin(relevant_tables_docs_indices) + relevant_tables_array = [x for x in self.table_df[indices]["col_def"]] + self.context["table_info"] = "\n\n".join(relevant_tables_array) + + self.context["examples_value"] = "\n\n".join([ + self.example_prompt.invoke(x).to_string() for x in self.example_selector.select_examples({"input": self.context["input"]}) + ]) + + self.log(self.context) + + response = self.create_chain_model().invoke({ "top_k": self.context["top_k"], "table_info": self.context["table_info"], - "examples_value": self.examples, + "examples_value": self.context["examples_value"], "input": input }) return response def execute_query(self, query): - cursor = self.cnx.cursor() + cursor = self.connection.cursor() # Execute the query cursor.execute(query)
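
Usage sketch (not part of the patch above): a minimal example of how the refactored SQLZilla class in python/sqlzilla/sqlzilla.py appears to be driven, based only on the constructor and methods shown in this diff. The connection string reuses the default credentials from app-main.py; the API key and the "Aviation" schema are placeholders taken from the few-shot examples and should be adjusted for your environment.

# sketch, assuming the defaults used by the Streamlit apps in this patch
from sqlzilla import SQLZilla

conn_str = "iris://_system:SYS@sqlzilla-iris-1:1972/IRISAPP"  # default host/credentials from app-main.py
zilla = SQLZilla(conn_str, openai_api_key="sk-...")           # placeholder OpenAI key

# builds the table-definition vector store and the few-shot example selector once
zilla.schema_context_management("Aviation")                   # schema name is an assumption

# generate SQL from a natural-language question, then run it against IRIS
sql = zilla.prompt("List all incidents that occurred in 'Canada'.")
rows = zilla.execute_query(sql)
print(rows)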