-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathconstraint_generation.py
158 lines (151 loc) · 6.7 KB
/
constraint_generation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import base64
from openai import OpenAI
import os
import cv2
import json
import parse
import numpy as np
import time
from datetime import datetime
def encode_image(image_path):
    """Read the file at *image_path* and return its bytes as base64 text.

    Args:
        image_path (str): path of the file to encode (used here for the query PNG).
    Returns:
        str: base64 encoding of the raw file contents, decoded to a Python string.
    """
    with open(image_path, "rb") as handle:
        payload = handle.read()
    return base64.b64encode(payload).decode('utf-8')
class ConstraintGenerator:
    """Query a vision-language model for constraint code and save the results.

    Each call to ``generate`` creates a timestamped task directory under
    ``<directory of this file>/vlm_query`` containing the query image, the prompt
    text, the raw model output, the extracted constraint functions (grouped into
    ``<group>_constraints.txt`` files) and ``metadata.json``.
    """

    # Default OpenAI-compatible endpoint; may be overridden via config['base_url'].
    _DEFAULT_BASE_URL = "https://api.chatanywhere.tech/v1"
    # Cap on the instruction-derived part of a task directory name so that long
    # instructions cannot exceed filesystem filename-length limits.
    _MAX_INSTRUCTION_CHARS = 100

    def __init__(self, config):
        """
        Args:
            config (dict): must provide 'model', 'temperature' and 'max_tokens'
                (read in generate); may provide 'base_url' to override the API endpoint.
        Raises:
            KeyError: if the OPENAI_API_KEY environment variable is not set.
        """
        self.config = config
        base_url = config.get('base_url', self._DEFAULT_BASE_URL)
        self.client = OpenAI(api_key=os.environ['OPENAI_API_KEY'], base_url=base_url)
        self.base_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), './vlm_query')
        with open(os.path.join(self.base_dir, 'prompt_template.txt'), 'r', encoding='utf-8') as f:
            self.prompt_template = f.read()
        # Directory of the current query; set by generate() and used by
        # _build_prompt() and _save_metadata().
        self.task_dir = None

    def _build_prompt(self, image_path, instruction):
        """Build the chat messages for the VLM query and save the prompt text.

        Args:
            image_path (str): path of the PNG query image.
            instruction (str): task instruction substituted into the prompt template.
        Returns:
            list: a single user message holding the formatted prompt text and the
            image as a base64 ``data:image/png`` URL.
        """
        img_base64 = encode_image(image_path)
        prompt_text = self.prompt_template.format(instruction=instruction)
        # Keep the exact prompt next to the model output for later inspection.
        with open(os.path.join(self.task_dir, 'prompt.txt'), 'w', encoding='utf-8') as f:
            f.write(prompt_text)
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": prompt_text,
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{img_base64}"
                        },
                    },
                ],
            }
        ]
        return messages

    def _parse_and_save_constraints(self, output, save_dir):
        """Extract top-level ``def`` blocks from *output* and save them grouped by name.

        A function starts at a line beginning with ``def `` and ends at the first
        ``return`` statement at the indentation of its first body line (so returns
        nested inside ``if``/``for`` blocks do not end it). Functions are grouped by
        their name minus the last ``_``-separated part (the constraint index), and
        each group is written to ``<group>_constraints.txt`` in *save_dir*.

        Args:
            output (str): raw model output.
            save_dir (str): existing directory to write the constraint files into.
        """
        lines = output.split("\n")
        functions = dict()
        name = None         # function currently being collected, if any
        start = None        # index of that function's `def` line
        body_indent = None  # indentation of its first non-blank body line
        for i, line in enumerate(lines):
            if line.startswith("def "):
                name = line.split("(")[0].split("def ")[1]
                start = i
                body_indent = None
                continue
            if name is None:
                # Text outside any function (prose, code fences, stray returns).
                continue
            stripped = line.lstrip()
            if not stripped:
                continue
            indent = len(line) - len(stripped)
            if body_indent is None:
                body_indent = indent
            if indent == body_indent and stripped.startswith("return "):
                functions[name] = lines[start:i + 1]
                name = None
        # organize them based on hierarchy in function names
        groupings = dict()
        for fn_name in functions:
            key = "_".join(fn_name.split("_")[:-1])  # last part is the constraint idx
            groupings.setdefault(key, []).append(fn_name)
        # save them into files
        for key, names in groupings.items():
            with open(os.path.join(save_dir, f"{key}_constraints.txt"), "w", encoding='utf-8') as f:
                for fn_name in names:
                    f.write("\n".join(functions[fn_name]) + "\n\n")
        print(f"Constraints saved to {save_dir}")

    @staticmethod
    def _find_assignment(output, key):
        """Return the text after ``<key> = `` on the first line of *output* matching it.

        Uses ``parse.parse``, so a line must match the template ``<key> = {value}``
        as a whole.

        Raises:
            ValueError: if no line matches.
        """
        template = f"{key} = {{{key}}}"
        for line in output.split("\n"):
            match = parse.parse(template, line)
            if match is not None:
                return match[key]
        raise ValueError(f"{key} not found in output")

    @staticmethod
    def _parse_int_list(text):
        """Convert list text such as ``"[0, -1]"`` into ``[0, -1]``; ``"[]"`` gives ``[]``."""
        tokens = text.replace("[", "").replace("]", "").split(",")
        # Skip empty tokens so an empty list does not fail on int('').
        return [int(tok.strip()) for tok in tokens if tok.strip()]

    def _parse_other_metadata(self, output):
        """Parse ``num_stages``, ``grasp_keypoints`` and ``release_keypoints`` from *output*.

        Returns:
            dict: ``num_stages`` as int, and the two keypoint entries as lists of ints.
        Raises:
            ValueError: if any of the three assignments is missing or not an integer.
        """
        data_dict = dict()
        data_dict['num_stages'] = int(self._find_assignment(output, 'num_stages'))
        data_dict['grasp_keypoints'] = self._parse_int_list(
            self._find_assignment(output, 'grasp_keypoints'))
        data_dict['release_keypoints'] = self._parse_int_list(
            self._find_assignment(output, 'release_keypoints'))
        return data_dict

    @staticmethod
    def _to_jsonable(obj):
        """``json.dump`` fallback: convert numpy arrays and scalars to Python objects."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def _save_metadata(self, metadata):
        """Write *metadata* to ``metadata.json`` in the task directory.

        Numpy arrays/scalars (at any nesting depth) are converted during
        serialization; the caller's dict is not modified.
        """
        path = os.path.join(self.task_dir, 'metadata.json')
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, default=self._to_jsonable)
        print(f"Metadata saved to {path}")

    @classmethod
    def _instruction_to_dirname(cls, instruction):
        """Turn *instruction* into a filesystem-safe directory-name fragment.

        Lower-cases the text and replaces every character that is not alphanumeric,
        ``_``, ``.`` or ``-`` (including spaces and path separators) with ``_``,
        then truncates to ``_MAX_INSTRUCTION_CHARS`` characters.
        """
        safe = "".join(
            ch if (ch.isalnum() or ch in "_.-") else "_" for ch in instruction.lower()
        )
        return safe[:cls._MAX_INSTRUCTION_CHARS]

    def _query_vlm(self, messages):
        """Stream a chat completion for *messages* and return the concatenated text."""
        stream = self.client.chat.completions.create(model=self.config['model'],
                                                     messages=messages,
                                                     temperature=self.config['temperature'],
                                                     max_tokens=self.config['max_tokens'],
                                                     stream=True)
        pieces = []
        start = time.time()
        for chunk in stream:
            print(f'[{time.time()-start:.2f}s] Querying OpenAI API...', end='\r')
            # Some OpenAI-compatible endpoints emit chunks with an empty `choices`
            # list (e.g. a usage-only chunk); skip them instead of raising IndexError.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content is not None:
                pieces.append(content)
        print(f'[{time.time()-start:.2f}s] Querying OpenAI API...Done')
        return "".join(pieces)

    def generate(self, img, instruction, metadata):
        """Query the VLM for constraints for *instruction* on *img* and save everything.

        Args:
            img (np.ndarray): image of the scene (H, W, 3) uint8. Its channels are
                reversed before ``cv2.imwrite`` (which expects BGR), so it is
                presumably RGB.
            instruction (str): instruction for the query.
            metadata (dict): extra metadata to store. It is updated in place with
                ``num_stages``, ``grasp_keypoints`` and ``release_keypoints`` parsed
                from the model output.
        Returns:
            str: the task directory holding the query image, prompt, raw output,
            constraint files and ``metadata.json``.
        Raises:
            OSError: if the query image cannot be written.
            ValueError: if the required metadata assignments are missing from the output.
        """
        # create a directory for the task (reused if the same name already exists)
        fname = (datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + "_"
                 + self._instruction_to_dirname(instruction))
        self.task_dir = os.path.join(self.base_dir, fname)
        os.makedirs(self.task_dir, exist_ok=True)
        # save query image; cv2.imwrite signals some failures by returning False
        image_path = os.path.join(self.task_dir, 'query_img.png')
        if not cv2.imwrite(image_path, img[..., ::-1]):
            raise OSError(f"Failed to write query image to {image_path}")
        messages = self._build_prompt(image_path, instruction)
        output = self._query_vlm(messages)
        # save raw output
        with open(os.path.join(self.task_dir, 'output_raw.txt'), 'w', encoding='utf-8') as f:
            f.write(output)
        self._parse_and_save_constraints(output, self.task_dir)
        metadata.update(self._parse_other_metadata(output))
        self._save_metadata(metadata)
        return self.task_dir