Skip to content

Commit

Permalink
followup questions
Browse files Browse the repository at this point in the history
  • Loading branch information
zainhoda committed Sep 22, 2023
1 parent fc9430a commit b4efdf1
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 15 deletions.
28 changes: 28 additions & 0 deletions src/vanna/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import plotly.express as px
import plotly.graph_objects as go
import requests
import re

from ..exceptions import DependencyError, ImproperlyConfigured, ValidationError
from ..types import TrainingPlan, TrainingPlanItem
Expand All @@ -35,6 +36,22 @@ def generate_sql(self, question: str, **kwargs) -> str:
llm_response = self.submit_prompt(prompt, **kwargs)
return llm_response

def generate_followup_questions(self, question: str, **kwargs) -> list[str]:
    """Generate follow-up questions a user might ask after an initial question.

    Retrieves similar question/SQL pairs, related DDL, and related
    documentation for *question*, builds a follow-up-questions prompt from
    them, and submits that prompt to the LLM.

    Args:
        question (str): The original user question.
        **kwargs: Passed through to the retrieval, prompt-building, and
            LLM-submission helpers.

    Returns:
        list[str]: The follow-up questions, one per element, with any
        leading "1. " style numbering stripped.  (The original annotation
        said ``str``, but the function returns the result of ``split``.)
    """
    question_sql_list = self.get_similar_question_sql(question, **kwargs)
    ddl_list = self.get_related_ddl(question, **kwargs)
    doc_list = self.get_related_documentation(question, **kwargs)
    prompt = self.get_followup_questions_prompt(
        question=question,
        question_sql_list=question_sql_list,
        ddl_list=ddl_list,
        doc_list=doc_list,
        **kwargs,
    )
    llm_response = self.submit_prompt(prompt, **kwargs)

    # LLMs typically prepend list numbering ("1. ", "2. ", ...); strip it
    # from the start of every line, then split into one question per line.
    numbers_removed = re.sub(r"^\d+\.\s*", "", llm_response, flags=re.MULTILINE)
    return numbers_removed.split("\n")

def generate_questions(self, **kwargs) -> list[str]:
"""
**Example:**
Expand Down Expand Up @@ -99,6 +116,17 @@ def get_sql_prompt(
):
pass

@abstractmethod
def get_followup_questions_prompt(
    self,
    question: str,
    question_sql_list: list,
    ddl_list: list,
    doc_list: list,
    **kwargs
):
    """Build the prompt used to ask the LLM for follow-up questions.

    Called by ``generate_followup_questions`` with the original question
    plus the retrieved similar question/SQL pairs, DDL statements, and
    documentation snippets.  Concrete subclasses return the prompt in
    whatever form their ``submit_prompt`` implementation accepts.

    NOTE(review): the OpenAI_Chat implementation declares an extra ``df``
    parameter that is absent here — confirm the intended signature.
    """
    pass

@abstractmethod
def submit_prompt(self, prompt, **kwargs) -> str:
pass
Expand Down
80 changes: 65 additions & 15 deletions src/vanna/openai/openai_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from abc import abstractmethod

import openai
import pandas as pd

from ..base import VannaBase

Expand Down Expand Up @@ -37,29 +38,56 @@ def user_message(message: str) -> dict:
def assistant_message(message: str) -> dict:
return {"role": "assistant", "content": message}

@staticmethod
def str_to_approx_token_count(string: str) -> int:
return len(string) / 4

@staticmethod
def add_ddl_to_prompt(initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000) -> str:
    """Append DDL statements to *initial_prompt*, subject to a token budget.

    Adds an introductory header, then each DDL statement that still fits
    under the approximate *max_tokens* limit.  Returns the (possibly
    unchanged) prompt text.
    """
    # Nothing to add — hand the prompt back untouched.
    if not ddl_list:
        return initial_prompt

    initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"

    estimate = OpenAI_Chat.str_to_approx_token_count
    for ddl_statement in ddl_list:
        # Only append statements that keep the prompt under the budget.
        if estimate(initial_prompt) + estimate(ddl_statement) < max_tokens:
            initial_prompt += f"{ddl_statement}\n\n"

    return initial_prompt

@staticmethod
def add_documentation_to_prompt(initial_prompt: str, documentation_list: list[str], max_tokens: int = 14000) -> str:
    """Append documentation snippets to *initial_prompt* within a token budget.

    Mirrors ``add_ddl_to_prompt``: a header is added first, then every
    snippet that still fits under the approximate *max_tokens* limit.
    """
    if not documentation_list:
        return initial_prompt

    # NOTE(review): the header mentions "tables" — likely copy-pasted from
    # the DDL helper; confirm the intended wording.
    initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"

    estimate = OpenAI_Chat.str_to_approx_token_count
    for doc_snippet in documentation_list:
        if estimate(initial_prompt) + estimate(doc_snippet) < max_tokens:
            initial_prompt += f"{doc_snippet}\n\n"

    return initial_prompt

@staticmethod
def add_sql_to_prompt(initial_prompt: str, sql_list: list[str], max_tokens: int = 14000) -> str:
    """Append past question/SQL pairs to *initial_prompt* within a token budget.

    Each entry in *sql_list* is a mapping with ``"question"`` and ``"sql"``
    keys; only the SQL text counts toward the budget check, matching the
    original behavior.
    """
    if not sql_list:
        return initial_prompt

    initial_prompt += f"\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"

    estimate = OpenAI_Chat.str_to_approx_token_count
    for entry in sql_list:
        if estimate(initial_prompt) + estimate(entry["sql"]) < max_tokens:
            initial_prompt += f"{entry['question']}\n{entry['sql']}\n\n"

    return initial_prompt

def get_sql_prompt(
self,
question: str,
question_sql_list: list,
ddl_list: list,
doc_list: list,
**kwargs,
) -> str:
):
initial_prompt = "The user provides a question and you provide SQL. You will only respond with SQL code and not with any explanations.\n\nRespond with only SQL code. Do not answer with any explanations -- just the code.\n"

if len(ddl_list) > 0:
initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
initial_prompt = OpenAI_Chat.add_ddl_to_prompt(initial_prompt, ddl_list, max_tokens=14000)

for ddl in ddl_list:
if len(initial_prompt) < 50000: # Add DDL if it fits
initial_prompt += f"{ddl}\n\n"

if len(doc_list) > 0:
initial_prompt += f"The following information may or may not be useful in constructing the SQL to answer the question\n"

for doc in doc_list:
if len(initial_prompt) < 60000: # Add Documentation if it fits
initial_prompt += f"{doc}\n\n"
initial_prompt = OpenAI_Chat.add_documentation_to_prompt(initial_prompt, doc_list, max_tokens=14000)

message_log = [OpenAI_Chat.system_message(initial_prompt)]

Expand All @@ -75,6 +103,28 @@ def get_sql_prompt(

return message_log

def get_followup_questions_prompt(
    self,
    question: str,
    df: pd.DataFrame = None,
    question_sql_list: list = None,
    ddl_list: list = None,
    doc_list: list = None,
    **kwargs
):
    """Build the message log asking the LLM for follow-up questions.

    Args:
        question (str): The user's original question.
        df (pd.DataFrame): The result data for the question.  Currently
            unused by this implementation; defaulted to ``None`` because
            ``VannaBase.generate_followup_questions`` calls this method
            without a ``df`` argument — the original required parameter
            made that call raise ``TypeError``.
        question_sql_list (list): Similar question/SQL pairs.
        ddl_list (list): Related DDL statements.
        doc_list (list): Related documentation snippets.

    Returns:
        list: OpenAI-style message dicts (system + user).
    """
    # Parameters after a defaulted one must also have defaults; normalize
    # the None sentinels back to empty lists.
    question_sql_list = question_sql_list if question_sql_list is not None else []
    ddl_list = ddl_list if ddl_list is not None else []
    doc_list = doc_list if doc_list is not None else []

    initial_prompt = f"The user initially asked the question: '{question}': \n\n"

    initial_prompt = OpenAI_Chat.add_ddl_to_prompt(initial_prompt, ddl_list, max_tokens=14000)

    initial_prompt = OpenAI_Chat.add_documentation_to_prompt(initial_prompt, doc_list, max_tokens=14000)

    initial_prompt = OpenAI_Chat.add_sql_to_prompt(initial_prompt, question_sql_list, max_tokens=14000)

    message_log = [OpenAI_Chat.system_message(initial_prompt)]
    message_log.append(OpenAI_Chat.user_message("Generate a list of followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions."))

    return message_log

def generate_question(self, sql: str, **kwargs) -> str:
response = self.submit_prompt(
[
Expand Down Expand Up @@ -150,7 +200,7 @@ def submit_prompt(self, prompt, **kwargs) -> str:
len(message["content"]) / 4
) # Use 4 as an approximation for the number of characters per token

if "engine" in self.config:
if self.config is not None and "engine" in self.config:
print(
f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
)
Expand All @@ -161,7 +211,7 @@ def submit_prompt(self, prompt, **kwargs) -> str:
stop=None,
temperature=0.7,
)
elif "model" in self.config:
elif self.config is not None and "model" in self.config:
print(
f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
)
Expand Down
50 changes: 50 additions & 0 deletions src/vanna/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,19 @@ def get_sql_prompt(
Not necessary for remote models as prompts are generated on the server side.
"""

def get_followup_questions_prompt(
    self,
    question: str,
    df: pd.DataFrame,
    question_sql_list: list,
    ddl_list: list,
    doc_list: list,
    **kwargs,
):
    """No-op: prompts for remote models are generated on the server side.

    This class's ``generate_followup_questions`` talks to the API directly
    via ``_rpc_call``, so this method is never used locally; it implicitly
    returns ``None``.
    """

def submit_prompt(self, prompt, **kwargs) -> str:
"""
Not necessary for remote models as prompts are handled on the server side.
Expand Down Expand Up @@ -420,3 +433,40 @@ def generate_sql(self, question: str, **kwargs) -> str:
sql_answer = SQLAnswer(**d["result"])

return sql_answer.sql

def generate_followup_questions(self, question: str, df: pd.DataFrame, **kwargs) -> list[str]:
    """Generate follow-up questions using the Vanna.AI API.

    **Example:**
    ```python
    vn.generate_followup_questions(question="What is the average salary of employees?", df=df)
    # ['What is the average salary of employees in the Sales department?', ...]
    ```

    Args:
        question (str): The question to generate follow-up questions for.
        df (pd.DataFrame): The DataFrame to generate follow-up questions for.

    Returns:
        List[str] or None: The follow-up questions, or None if an error occurred.
    """
    # The payload carries only the original question; the DataFrame is
    # represented as empty markdown in this request.
    data_result = DataResult(
        question=question,
        sql=None,
        table_markdown="",
        error=None,
        correction_attempts=0,
    )

    response = self._rpc_call(
        method="generate_followup_questions", params=[data_result]
    )

    # No "result" key means the RPC failed or returned an unexpected shape.
    if "result" not in response:
        return None

    # Deserialize into the dataclass and surface just the question strings.
    return QuestionStringList(**response["result"]).questions

0 comments on commit b4efdf1

Please sign in to comment.