Skip to content

Commit

Permalink
Merge pull request #107 from vanna-ai/additional-options
Browse files Browse the repository at this point in the history
Additional options for engine and warehouse
  • Loading branch information
zainhoda authored Aug 31, 2023
2 parents 40ec1c3 + ff22214 commit a791c89
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 10 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "vanna"
version = "0.0.22"
version = "0.0.23"
authors = [
{ name="Zain Hoda", email="[email protected]" },
]
Expand Down
4 changes: 4 additions & 0 deletions src/vanna/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def connect_to_snowflake(
password: str,
database: str,
role: Union[str, None] = None,
warehouse: Union[str, None] = None,
):
try:
snowflake = __import__("snowflake.connector")
Expand Down Expand Up @@ -150,6 +151,9 @@ def run_sql_snowflake(sql: str) -> pd.DataFrame:

if role is not None:
cs.execute(f"USE ROLE {role}")

if warehouse is not None:
cs.execute(f"USE WAREHOUSE {warehouse}")
cs.execute(f"USE DATABASE {database}")

cur = cs.execute(sql)
Expand Down
30 changes: 21 additions & 9 deletions src/vanna/openai_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def generate_question(self, sql: str, **kwargs) -> str:
"The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question."
),
self.user_message(sql),
]
],
**kwargs,
)

return response
Expand Down Expand Up @@ -149,16 +150,27 @@ def submit_prompt(self, prompt, **kwargs) -> str:
len(message["content"]) / 4
) # Use 4 as an approximation for the number of characters per token

if num_tokens > 3500:
model = "gpt-3.5-turbo-16k"
if "engine" in self.config:
print(
f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
)
response = openai.ChatCompletion.create(
engine=self.config["engine"],
messages=prompt,
max_tokens=500,
stop=None,
temperature=0.7,
)
else:
model = "gpt-3.5-turbo"

print(f"Using model {model} for {num_tokens} tokens (approx)")
if num_tokens > 3500:
model = "gpt-3.5-turbo-16k"
else:
model = "gpt-3.5-turbo"

response = openai.ChatCompletion.create(
model=model, messages=prompt, max_tokens=500, stop=None, temperature=0.7
)
print(f"Using model {model} for {num_tokens} tokens (approx)")
response = openai.ChatCompletion.create(
model=model, messages=prompt, max_tokens=500, stop=None, temperature=0.7
)

for (
choice
Expand Down

0 comments on commit a791c89

Please sign in to comment.