-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathopenai_utils.py
63 lines (49 loc) · 1.63 KB
/
openai_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import openai
def create_table_definition_prompt(df, table_name):
    """Build the table-description header of a text-to-SQL prompt.

    The header is a block of ``#``-prefixed lines naming the table and
    listing its columns as ``table_name(col1,col2,...)``.

    Args:
        df (dataframe): Object whose ``columns`` attribute is iterated to
            obtain the column names (each one is converted with ``str``).
        table_name (string): Name of the table within the database.

    Returns:
        string: The prompt header, terminated by a newline.
    """
    # Column labels are not necessarily strings, so stringify before joining.
    column_list = ",".join(map(str, df.columns))
    return (
        "### sqlite table, with its properties:\n"
        "#\n"
        f"# {table_name}({column_list})\n"
        "#\n"
    )
def user_query_input():
    """Prompt on the console for the question to ask about the data.

    Returns:
        string: The line the user typed, exactly as returned by ``input``.
    """
    return input("Tell OpenAi what you want to know about the data: ")
def combine_prompts(fixed_sql_prompt, user_query):
    """Append the user's question to the fixed table-description prompt.

    Args:
        fixed_sql_prompt (string): Prompt text that precedes the question.
        user_query (string): Question to be answered with a SQL query.

    Returns:
        string: ``fixed_sql_prompt`` followed by a ``### A query to answer:``
        line and a final line containing only ``SELECT``.
    """
    # The prompt deliberately ends with "SELECT" so that the completion
    # continues an already-started SQL statement.
    return fixed_sql_prompt + f"### A query to answer: {user_query}\nSELECT"
def send_to_openai(prompt, *, engine="code-davinci-002", max_tokens=150,
                   temperature=0):
    """Send the prompt to the OpenAI completions endpoint.

    Args:
        prompt (string): Prompt to send to OpenAI.
        engine (string): Name of the completion engine to use. Defaults to
            ``"code-davinci-002"``.
        max_tokens (int): Upper bound on the number of generated tokens.
            Defaults to 150.
        temperature (float): Sampling temperature. Defaults to 0.

    Returns:
        The response object returned by ``openai.Completion.create``,
        unmodified. This is NOT a plain string; the generated text is
        presumably found under ``response["choices"][0]["text"]`` -- confirm
        against the installed ``openai`` version.
    """
    # NOTE(review): OpenAI has announced deprecation of the Codex engines;
    # if the default engine is unavailable, pass a different ``engine``.
    response = openai.Completion.create(
        engine=engine,
        prompt=prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=1.0,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        # Presumably chosen so generation ends at the end of one SQL
        # statement (";") or before a new "#" comment line begins.
        stop=["#", ";"],
    )
    return response