diff --git a/pyproject.toml b/pyproject.toml index ebb638a7..b877c47e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "vanna" -version = "0.0.22" +version = "0.0.23" authors = [ { name="Zain Hoda", email="zain@vanna.ai" }, ] diff --git a/src/vanna/base.py b/src/vanna/base.py index 49d47130..ce70b2fa 100644 --- a/src/vanna/base.py +++ b/src/vanna/base.py @@ -97,6 +97,7 @@ def connect_to_snowflake( password: str, database: str, role: Union[str, None] = None, + warehouse: Union[str, None] = None, ): try: snowflake = __import__("snowflake.connector") @@ -150,6 +151,9 @@ def run_sql_snowflake(sql: str) -> pd.DataFrame: if role is not None: cs.execute(f"USE ROLE {role}") + + if warehouse is not None: + cs.execute(f"USE WAREHOUSE {warehouse}") cs.execute(f"USE DATABASE {database}") cur = cs.execute(sql) diff --git a/src/vanna/openai_chat.py b/src/vanna/openai_chat.py index 7f3de653..02e91309 100644 --- a/src/vanna/openai_chat.py +++ b/src/vanna/openai_chat.py @@ -82,7 +82,8 @@ def generate_question(self, sql: str, **kwargs) -> str: "The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question." ), self.user_message(sql), - ] + ], + **kwargs, ) return response @@ -149,16 +150,27 @@ def submit_prompt(self, prompt, **kwargs) -> str: len(message["content"]) / 4 ) # Use 4 as an approximation for the number of characters per token - if num_tokens > 3500: - model = "gpt-3.5-turbo-16k" + if "engine" in self.config: + print( + f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)" + ) + response = openai.ChatCompletion.create( + engine=self.config["engine"], + messages=prompt, + max_tokens=500, + stop=None, + temperature=0.7, + ) else: - model = "gpt-3.5-turbo" - - print(f"Using model {model} for {num_tokens} tokens (approx)") + if num_tokens > 3500: + model = "gpt-3.5-turbo-16k" + else: + model = "gpt-3.5-turbo" - response = openai.ChatCompletion.create( - model=model, messages=prompt, max_tokens=500, stop=None, temperature=0.7 - ) + print(f"Using model {model} for {num_tokens} tokens (approx)") + response = openai.ChatCompletion.create( + model=model, messages=prompt, max_tokens=500, stop=None, temperature=0.7 + ) for ( choice