-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
83 lines (71 loc) · 1.97 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
import os
import re
import json
# Load API keys from the project's env file before any client is created.
load_dotenv("secret/.env")

# Chat model shared by the rest of the script; temperature is pinned to 0.
model = ChatOpenAI(model="gpt-4o-mini", temperature=0)
# get prompts
def _read_text(path):
    """Return the full contents of the text file at *path*, decoded as UTF-8.

    The encoding is given explicitly so the prompt assets load the same way
    regardless of the platform's default locale encoding.
    """
    with open(path, "r", encoding="utf-8") as file:
        return file.read()


system_txt = _read_text("llm/prompts/system.txt")
template_txt = _read_text("llm/prompts/template.txt")
schema_json = _read_text("llm/prompts/schema.json")
examples_json = _read_text("llm/prompts/examples.json")
# The image file's content is used verbatim as a URL below, so drop any
# surrounding whitespace (typically the trailing newline an editor adds).
image_data = _read_text("llm/prompts/test_image.txt").strip()

# The schema and examples are passed as variable *values*, so their own
# braces are inserted as-is rather than being parsed as placeholders.
# NOTE(review): literal braces inside template.txt itself presumably still
# need escaping for PromptTemplate's default format - confirm.
prompt = PromptTemplate(
    input_variables=["schema_json", "examples_json"],
    template=template_txt,
)
prompt_txt = prompt.format(
    schema_json=schema_json,
    examples_json=examples_json,
)
# prepare messages
# The human turn carries two content parts: the rendered prompt text and the
# image referenced by URL.
_text_part = {"type": "text", "text": prompt_txt}
_image_part = {
    "type": "image_url",
    "image_url": {"url": image_data, "detail": "low"},
}

messages = [
    SystemMessage(content=system_txt),
    HumanMessage(content=[_text_part, _image_part]),
]
def _extract_json(text):
    """Return the JSON object that starts at the first ``{`` in *text*.

    Decoding stops at the end of that object, so anything after it (prose,
    a closing code fence, further braces) is ignored instead of being fed
    to the parser.

    Returns:
        The decoded object, or ``None`` when *text* contains no ``{`` that is
        followed by a ``}`` (i.e. nothing that looks like a JSON object).

    Raises:
        json.JSONDecodeError: if the text starting at the first ``{`` is not
            valid JSON.
    """
    start = text.find("{")
    if start == -1 or text.rfind("}") < start:
        return None
    obj, _consumed = json.JSONDecoder().raw_decode(text[start:])
    return obj


# initialize parser
parser = StrOutputParser()
# prepare chain: the model's reply is piped through the string parser
chain = model | parser
res = chain.invoke(messages)
print(res)

# extract JSON from the model's response
json_output = None  # stays None when nothing could be extracted
try:
    json_output = _extract_json(res)
except json.JSONDecodeError:
    print("Failed to parse JSON response.")
else:
    if json_output is None:
        print("No JSON found in the response.")
    else:
        print(json_output)