diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index a5c0150d..17d364d7 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -10,6 +10,7 @@ import plotly.express as px import plotly.graph_objects as go import requests +import re from ..exceptions import DependencyError, ImproperlyConfigured, ValidationError from ..types import TrainingPlan, TrainingPlanItem @@ -35,6 +36,22 @@ def generate_sql(self, question: str, **kwargs) -> str: llm_response = self.submit_prompt(prompt, **kwargs) return llm_response + def generate_followup_questions(self, question: str, **kwargs) -> str: + question_sql_list = self.get_similar_question_sql(question, **kwargs) + ddl_list = self.get_related_ddl(question, **kwargs) + doc_list = self.get_related_documentation(question, **kwargs) + prompt = self.get_followup_questions_prompt( + question=question, + question_sql_list=question_sql_list, + ddl_list=ddl_list, + doc_list=doc_list, + **kwargs, + ) + llm_response = self.submit_prompt(prompt, **kwargs) + + numbers_removed = re.sub(r'^\d+\.\s*', '', llm_response, flags=re.MULTILINE) + return numbers_removed.split("\n") + def generate_questions(self, **kwargs) -> list[str]: """ **Example:** @@ -99,6 +116,17 @@ def get_sql_prompt( ): pass + @abstractmethod + def get_followup_questions_prompt( + self, + question: str, + question_sql_list: list, + ddl_list: list, + doc_list: list, + **kwargs + ): + pass + @abstractmethod def submit_prompt(self, prompt, **kwargs) -> str: pass diff --git a/src/vanna/openai/openai_chat.py b/src/vanna/openai/openai_chat.py index 1646168e..72ff8a36 100644 --- a/src/vanna/openai/openai_chat.py +++ b/src/vanna/openai/openai_chat.py @@ -2,6 +2,7 @@ from abc import abstractmethod import openai +import pandas as pd from ..base import VannaBase @@ -37,6 +38,43 @@ def user_message(message: str) -> dict: def assistant_message(message: str) -> dict: return {"role": "assistant", "content": message} + @staticmethod + def str_to_approx_token_count(string: str) -> int: + return len(string) / 4 + + @staticmethod + def add_ddl_to_prompt(initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000) -> str: + if len(ddl_list) > 0: + initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" + + for ddl in ddl_list: + if OpenAI_Chat.str_to_approx_token_count(initial_prompt) + OpenAI_Chat.str_to_approx_token_count(ddl) < max_tokens: + initial_prompt += f"{ddl}\n\n" + + return initial_prompt + + @staticmethod + def add_documentation_to_prompt(initial_prompt: str, documentation_list: list[str], max_tokens: int = 14000) -> str: + if len(documentation_list) > 0: + initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" + + for documentation in documentation_list: + if OpenAI_Chat.str_to_approx_token_count(initial_prompt) + OpenAI_Chat.str_to_approx_token_count(documentation) < max_tokens: + initial_prompt += f"{documentation}\n\n" + + return initial_prompt + + @staticmethod + def add_sql_to_prompt(initial_prompt: str, sql_list: list[str], max_tokens: int = 14000) -> str: + if len(sql_list) > 0: + initial_prompt += f"\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" + + for question in sql_list: + if OpenAI_Chat.str_to_approx_token_count(initial_prompt) + OpenAI_Chat.str_to_approx_token_count(question["sql"]) < max_tokens: + initial_prompt += f"{question['question']}\n{question['sql']}\n\n" + + return initial_prompt + def get_sql_prompt( self, question: str, @@ -44,22 +82,12 @@ def get_sql_prompt( ddl_list: list, doc_list: list, **kwargs, - ) -> str: + ): initial_prompt = "The user provides a question and you provide SQL. You will only respond with SQL code and not with any explanations.\n\nRespond with only SQL code. Do not answer with any explanations -- just the code.\n" - if len(ddl_list) > 0: - initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" + initial_prompt = OpenAI_Chat.add_ddl_to_prompt(initial_prompt, ddl_list, max_tokens=14000) - for ddl in ddl_list: - if len(initial_prompt) < 50000: # Add DDL if it fits - initial_prompt += f"{ddl}\n\n" - - if len(doc_list) > 0: - initial_prompt += f"The following information may or may not be useful in constructing the SQL to answer the question\n" - - for doc in doc_list: - if len(initial_prompt) < 60000: # Add Documentation if it fits - initial_prompt += f"{doc}\n\n" + initial_prompt = OpenAI_Chat.add_documentation_to_prompt(initial_prompt, doc_list, max_tokens=14000) message_log = [OpenAI_Chat.system_message(initial_prompt)] @@ -75,6 +103,28 @@ def get_sql_prompt( return message_log + def get_followup_questions_prompt( + self, + question: str, + df: pd.DataFrame, + question_sql_list: list, + ddl_list: list, + doc_list: list, + **kwargs + ): + initial_prompt = f"The user initially asked the question: '{question}': \n\n" + + initial_prompt = OpenAI_Chat.add_ddl_to_prompt(initial_prompt, ddl_list, max_tokens=14000) + + initial_prompt = OpenAI_Chat.add_documentation_to_prompt(initial_prompt, doc_list, max_tokens=14000) + + initial_prompt = OpenAI_Chat.add_sql_to_prompt(initial_prompt, question_sql_list, max_tokens=14000) + + message_log = [OpenAI_Chat.system_message(initial_prompt)] + message_log.append(OpenAI_Chat.user_message("Generate a list of followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions.")) + + return message_log + def generate_question(self, sql: str, **kwargs) -> str: response = self.submit_prompt( [ @@ -150,7 +200,7 @@ def submit_prompt(self, prompt, **kwargs) -> str: len(message["content"]) / 4 ) # Use 4 as an approximation for the number of characters per token - if "engine" in self.config: + if self.config is not None and "engine" in self.config: print( f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)" ) @@ -161,7 +211,7 @@ def submit_prompt(self, prompt, **kwargs) -> str: stop=None, temperature=0.7, ) - elif "model" in self.config: + elif self.config is not None and "model" in self.config: print( f"Using model {self.config['model']} for {num_tokens} tokens (approx)" ) diff --git a/src/vanna/remote.py b/src/vanna/remote.py index 194babd5..c5f843ae 100644 --- a/src/vanna/remote.py +++ b/src/vanna/remote.py @@ -373,6 +373,19 @@ def get_sql_prompt( Not necessary for remote models as prompts are generated on the server side. """ + def get_followup_questions_prompt( + self, + question: str, + df: pd.DataFrame, + question_sql_list: list, + ddl_list: list, + doc_list: list, + **kwargs, + ): + """ + Not necessary for remote models as prompts are generated on the server side. + """ + def submit_prompt(self, prompt, **kwargs) -> str: """ Not necessary for remote models as prompts are handled on the server side. @@ -420,3 +433,40 @@ def generate_sql(self, question: str, **kwargs) -> str: sql_answer = SQLAnswer(**d["result"]) return sql_answer.sql + + def generate_followup_questions(self, question: str, df: pd.DataFrame, **kwargs) -> list[str]: + """ + **Example:** + ```python + vn.generate_followup_questions(question="What is the average salary of employees?", df=df) + # ['What is the average salary of employees in the Sales department?', 'What is the average salary of employees in the Engineering department?', ...] + ``` + + Generate follow-up questions using the Vanna.AI API. + + Args: + question (str): The question to generate follow-up questions for. + df (pd.DataFrame): The DataFrame to generate follow-up questions for. + + Returns: + List[str] or None: The follow-up questions, or None if an error occurred. + """ + params = [ + DataResult( + question=question, + sql=None, + table_markdown="", + error=None, + correction_attempts=0, + ) + ] + + d = self._rpc_call(method="generate_followup_questions", params=params) + + if "result" not in d: + return None + + # Load the result into a dataclass + question_string_list = QuestionStringList(**d["result"]) + + return question_string_list.questions \ No newline at end of file