Skip to content

Commit

Permalink
🐛 Send prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
henryhamon committed Jul 29, 2024
1 parent d8996d7 commit 5d88b2a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 15 deletions.
13 changes: 7 additions & 6 deletions python/sqlzilla/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,6 @@ def assistant_interaction(sqlzilla, prompt):
response = sqlzilla.prompt(prompt)
st.session_state.chat_history.append({"role": "user", "content": prompt})
st.session_state.chat_history.append({"role": "assistant", "content": response})

# Check if the response contains SQL code and update the editor
if "SELECT" in response.upper():
st.session_state.query_result = response

return response

Expand Down Expand Up @@ -93,6 +89,7 @@ def assistant_interaction(sqlzilla, prompt):

# Initial prompts for namespace and database schema
database_schema = st.text_input('Enter Database Schema')
editor = code_editor("-- your query", lang="sql", height=[10, 100], shortcuts="vscode")

if st.session_state.namespace and database_schema and st.session_state.openai_api_key:
sqlzilla = SQLZilla(db_connection_str(), st.session_state.openai_api_key)
Expand All @@ -102,8 +99,7 @@ def assistant_interaction(sqlzilla, prompt):
col1, col2 = st.columns(2)

with col1:
code_editor("-- your query", lang="sql", height=[10, 100], shortcuts="vscode")

st.write(editor)
# Buttons to run, save, and clear the query in a single row
run_button, clear_button = st.columns([1, 1])
with run_button:
Expand Down Expand Up @@ -131,6 +127,11 @@ def assistant_interaction(sqlzilla, prompt):
st.session_state.chat_history.append({"role": "user", "content": prompt})

response = assistant_interaction(sqlzilla, prompt)

# Check if the response contains SQL code and update the editor
if "SELECT" in response.upper():
st.session_state.query_result = response
editor.text = response
# Display assistant response in chat message container
with st.chat_message("assistant"):
st.markdown(response)
Expand Down
13 changes: 4 additions & 9 deletions python/sqlzilla/sqlzilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,19 +204,20 @@ def schema_context_management(self, schema):


def prompt(self, input):
self.context["input"] = input
db = IRISVector.from_documents(
embedding = OpenAIEmbeddings(openai_api_key=self.openai_api_key),
documents = self.tables_docs,
connection_string= self.iris_conn_str,
collection_name="sql_tables",
ids=self.tables_docs_ids
)
relevant_tables_docs = db.similarity_search(self.context["input"])
relevant_tables_docs = db.similarity_search(input)
relevant_tables_docs_indices = [x.metadata["id"] for x in relevant_tables_docs]
indices = self.table_df["id"].isin(relevant_tables_docs_indices)
relevant_tables_array = [x for x in self.table_df[indices]["col_def"]]
self.context["table_info"] = "\n\n".join(relevant_tables_array)
new_sql_samples, sql_samples_ids = self.ilter_not_in_collection(
new_sql_samples, sql_samples_ids = self.filter_not_in_collection(
"sql_samples",
self.examples,
self.get_ids_from_string_array([x['input'] for x in self.examples])
Expand Down Expand Up @@ -246,20 +247,14 @@ def prompt(self, input):
+ ChatPromptTemplate.from_messages([("system", prompt_sql_few_shots_template)])
+ ChatPromptTemplate.from_messages([("human", user_prompt)])
)
prompt_value = prompt.invoke({
"top_k": self.context["top_k"],
"table_info": self.context["table_info"],
"examples_value": self.context["examples_value"],
"input": input
})

model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0, api_key=self.openai_api_key)
output_parser = StrOutputParser()
chain_model = prompt | model | output_parser
response = chain_model.invoke({
"top_k": self.context["top_k"],
"table_info": self.context["table_info"],
"examples_value": self.context["examples_value"],
"examples_value": self.examples,
"input": input
})
return response

0 comments on commit 5d88b2a

Please sign in to comment.