Commit

one iris vector collection per schema
jrpereirajr committed Aug 15, 2024
1 parent e04b455 commit b6cd128
Showing 2 changed files with 15 additions and 8 deletions.
4 changes: 3 additions & 1 deletion python/sqlzilla/app.py
@@ -83,7 +83,7 @@ def clean_response(response):
 
 database_schema = None
 if (st.session_state.namespace and st.session_state.openai_api_key):
-    sqlzilla = SQLZilla(db_connection_str(), st.session_state.openai_api_key, database_schema)
+    sqlzilla = SQLZilla(db_connection_str(), st.session_state.openai_api_key)
     # Initial prompts for namespace and database schema
     try:
         query = """
@@ -98,6 +98,7 @@ def clean_response(response):
             index=None,
             placeholder="Select database schema...",
         )
+        sqlzilla.schema_name = database_schema
     except:
         database_schema = st.text_input('Enter Database Schema')
         st.warning('Was not possible to retrieve database schemas. Please provide it manually.')
@@ -130,6 +131,7 @@ def clean_response(response):
 
     if st.button("Save on library"):
        sqlzilla.add_example(st.session_state.prompt, st.session_state.code_text)
+       st.success("Saved on library!")
 
 with col2:
     # Display chat history
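For readers following the app.py change, a minimal usage sketch (not part of the commit) of the new call pattern: SQLZilla is now constructed without a schema, and schema_name is assigned only once the user has picked one. The import path and the IRIS connection-string format below are assumptions for illustration.

    # Illustrative sketch only; values are placeholders, not repository code.
    from sqlzilla import SQLZilla  # assumed import path within python/sqlzilla/

    connection_string = "iris://user:password@localhost:1972/USER"  # assumed format
    openai_api_key = "sk-..."                                       # placeholder

    # The schema argument is gone from the constructor...
    sqlzilla = SQLZilla(connection_string, openai_api_key)
    # ...and the schema is set after the user selects it in the UI.
    sqlzilla.schema_name = "SQLUser"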
19 changes: 12 additions & 7 deletions python/sqlzilla/sqlzilla.py
@@ -13,10 +13,10 @@
 from langchain_iris import IRISVector
 
 class SQLZilla:
-    def __init__(self, connection_string, openai_api_key, schema_name="SQLUser"):
+    def __init__(self, connection_string, openai_api_key):
         self.log('criou')
         self.openai_api_key = openai_api_key
-        self.schema_name = schema_name
+        self.schema_name = None
         self.engine = create_engine(connection_string)
         self.conn_wrapper = self.engine.connect()
         self.connection = self.conn_wrapper.connection
@@ -42,6 +42,8 @@ def create_examples_table(self):
 
     def get_examples(self):
         sql = "SELECT prompt, query FROM sqlzilla.examples WHERE schema_name = %s"
+        self.log('sql: ' + sql)
+        self.log('params: ' + str([self.schema_name]))
         rows = self.execute_query(sql, [self.schema_name])
         examples = [{
             "input": row[0],
@@ -50,7 +52,7 @@ def get_examples(self):
         return examples
 
     def add_example(self, prompt, query):
-        sql = "INSERT INTO sqlzilla.examples (prompt, query, schema) VALUES (%s, %s, %s)"
+        sql = "INSERT INTO sqlzilla.examples (prompt, query, schema_name) VALUES (%s, %s, %s)"
         self.execute_query(sql, [prompt, query, self.schema_name])
 
     def __del__(self):
@@ -176,8 +178,10 @@ def schema_context_management(self, schema):
         documents = loader.load()
         text_splitter = CharacterTextSplitter(chunk_size=400, chunk_overlap=20, separator="\n")
         self.tables_docs = text_splitter.split_documents(documents)
+        self.log('schema_name: ' + str(self.schema_name))
+        collection_name_tables = "sql_tables_"+self.schema_name
         new_tables_docs, tables_docs_ids = self.filter_not_in_collection(
-            "sql_tables",
+            collection_name_tables,
             self.tables_docs,
             self.get_ids_from_string_array([x.page_content for x in self.tables_docs])
         )
@@ -186,14 +190,15 @@ def schema_context_management(self, schema):
             embedding = OpenAIEmbeddings(openai_api_key=self.openai_api_key),
             documents = self.tables_docs,
             connection=self.conn_wrapper,
-            collection_name="sql_tables",
+            collection_name=collection_name_tables,
             ids=self.tables_docs_ids
         )
 
         if self.example_selector is None:
             examples = self.get_examples()
+            collection_name_examples = "sql_samples_"+self.schema_name
             new_sql_samples, sql_samples_ids = self.filter_not_in_collection(
-                "sql_samples",
+                collection_name_examples,
                 examples,
                 self.get_ids_from_string_array([x['input'] for x in examples])
             )
@@ -204,7 +209,7 @@ def schema_context_management(self, schema):
                 k=5,
                 input_keys=["input"],
                 connection=self.conn_wrapper,
-                collection_name="sql_samples",
+                collection_name=collection_name_examples,
                 ids=sql_samples_ids
             )
 
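For context, a minimal sketch (not from the repository) of the pattern this commit introduces in sqlzilla.py: each schema gets its own IRIS vector collection, so table metadata embedded for one schema never mixes with another's. The helper name and the OpenAIEmbeddings import path are assumptions; the IRISVector parameters mirror those visible in the diff above.

    # Illustrative sketch of per-schema collection naming; not repository code.
    from langchain_openai import OpenAIEmbeddings  # import path assumed
    from langchain_iris import IRISVector

    def build_tables_store(docs, conn, openai_api_key, schema_name):
        # One collection per schema: "sql_tables_<schema>" instead of a shared "sql_tables".
        collection_name = "sql_tables_" + schema_name
        return IRISVector.from_documents(
            embedding=OpenAIEmbeddings(openai_api_key=openai_api_key),
            documents=docs,
            connection=conn,
            collection_name=collection_name,
        )

    # e.g. schemas "SQLUser" and "HR" land in "sql_tables_SQLUser" and "sql_tables_HR".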
