# runner_inference_canny.py
# (Captured from a GitHub file view: 255 lines (219 loc) · 8.48 KB.)
import os, sys
import time, re, json, shutil
import requests, subprocess, random
import argparse
parser = argparse.ArgumentParser()

# Path of the DeSOTA request (YAML) that describes the job to run.
parser.add_argument(
    "-mr", "--model_req",
    type=str,
    help="DeSOTA Request as yaml file path",
)

# Destination of the result. main() runs detools.get_url_from_str on this value:
# if no URL is found it switches to dev mode and treats the value as a local
# report path instead of an API endpoint.
parser.add_argument(
    "-mru", "--model_res_url",
    type=str,
    help="DeSOTA API Result URL. Recognize path instead of url for desota tests",
)
from requests.adapters import HTTPAdapter, Retry

# Shared HTTP session used to upload results to the DeSOTA API.
# Up to 5 retries with a short backoff, and status-based retries on
# 500/502/503/504 — mounted for https:// URLs only.
# NOTE(review): urllib3's Retry does not retry POST on status codes by default
# (POST is not in its default allowed_methods), so for the upload only
# connection-level retries apply — confirm whether that is intended.
retries = Retry(
    total=5,
    backoff_factor=0.1,
    status_forcelist=[500, 502, 503, 504],
)
s = requests.Session()
s.mount('https://', HTTPAdapter(max_retries=retries))
DEBUG = False
CURRENT_PATH = os.path.dirname(os.path.realpath(__file__))
# DeSOTA Funcs [START]
# > Import DeSOTA Scripts
from desota import detools
# > Grab DeSOTA Paths
USER_SYS = detools.get_platform()
APP_PATH = os.path.dirname(os.path.realpath(__file__))
TMP_PATH = os.path.join(CURRENT_PATH, "tmp")
# > USER_PATH: the directory that contains the "desota" folder this app lives in
#   (everything in APP_PATH before the first component named "desota",
#   case-insensitive). USER is the last component of USER_PATH — presumably the
#   user's home-folder name; confirm.
_PATH_SEPARATORS = {"win": "\\", "lin": "/"}
if USER_SYS not in _PATH_SEPARATORS:
    # Previously this surfaced later as a NameError on USER_PATH.
    raise RuntimeError(f"Unsupported platform for DeSOTA runner: {USER_SYS!r}")
_path_sep = _PATH_SEPARATORS[USER_SYS]
path_split = str(APP_PATH).split(_path_sep)
try:
    desota_idx = [ps.lower() for ps in path_split].index("desota")
except ValueError as err:
    raise ValueError(
        f"Runner must be installed inside a 'Desota' directory, got: {APP_PATH}"
    ) from err
USER = path_split[desota_idx - 1]
USER_PATH = _path_sep.join(path_split[:desota_idx])
# NOTE(review): the lookup above is case-insensitive, but this join hard-codes
# "Desota" — on case-sensitive filesystems the folder name must match exactly.
DESOTA_ROOT_PATH = os.path.join(USER_PATH, "Desota")
CONFIG_PATH = os.path.join(DESOTA_ROOT_PATH, "Configs")
SERV_CONF_PATH = os.path.join(CONFIG_PATH, "services.config.yaml")
# DeSOTA Funcs [END]
# Placeholder prompt value meaning "substitute the request text here".
_ARGUMENT_PLACEHOLDER = '-=#{([$argument$])}#=-'

# User-facing condition names -> names passed to main.py's --condition flag.
# Names not listed here (e.g. "canny", "shuffle") are passed through unchanged.
_CONDITION_ALIASES = {
    "pose": "openpose",
    "scribble": "scribble_hedsafe",
    "softedge": "softedge_hedsafe",
    "face-geometry": "mediapipe_face",
    "normals": "normal_bae",
    "geometry": "mlsd",
    "lineart": "lineart_realistic",
    "anime": "lineart_anime",
    "depth": "depth_midas",
}


def _build_model_args(req_text, targs=None):
    """Return the main.py arguments as a dict, filling in defaults.

    req_text: prompt text taken from the DeSOTA request.
    targs: optional pre-set arguments (not mutated). Missing keys get defaults;
        a missing or placeholder prompt is replaced by ``req_text``; a missing or
        non-numeric seed is replaced by a random one.
    """
    targs = dict(targs or {})
    if targs.get('prompt', _ARGUMENT_PLACEHOLDER) == _ARGUMENT_PLACEHOLDER:
        targs['prompt'] = req_text
    condition = targs.get('condition', "canny")
    targs['condition'] = _CONDITION_ALIASES.get(condition, condition)
    targs.setdefault('width', "256")
    targs.setdefault('height', "256")
    targs.setdefault('video_length', "15")
    targs.setdefault('frame_rate', "2")
    # The original code accessed `targs.seed` (attribute access on a dict).
    seed = str(targs.get('seed', ""))
    targs['seed'] = seed if seed.isdigit() else str(random.randint(1, 1000000))
    targs['version'] = "v11"
    return targs


def main(args):
    '''
    Run the ControlVideo (canny) model for one DeSOTA request and deliver the result.

    Reads the request from ``args.model_req`` and runs ``main.py`` with the app's
    bundled ``env`` interpreter. Then it either writes a JSON report (dev mode,
    when no URL can be extracted from ``args.model_res_url``) or uploads the
    video to the DeSOTA API. Always terminates the process via ``sys.exit``.

    return codes:
    0 = SUCCESS
    1 = INPUT ERROR
    2 = OUTPUT ERROR
    3 = API RESPONSE ERROR
    9 = REINSTALL MODEL (critical fail)
    '''
    # Time when the request was picked up
    _report_start_time = time.time()
    start_time = int(_report_start_time)

    #---INPUT---# TODO (PRO ARGS)
    _resnum = 5
    #---INPUT---#

    # DeSOTA Model Request
    model_request_dict = detools.get_model_req(args.model_req)

    # API Response URL; if it contains no URL it is a local report path (dev mode).
    send_task_url = args.model_res_url
    out_urls = detools.get_url_from_str(send_task_url)
    dev_mode = not out_urls
    report_path = send_task_url if dev_mode else out_urls[0]

    # TARGET File Path
    out_filename = f"canny-video-{start_time}.mp4"
    out_filepath = os.path.join(TMP_PATH, out_filename)

    # Get text from request; several texts are merged into one prompt.
    _req_text = detools.get_request_text(model_request_dict)
    if isinstance(_req_text, list):
        _req_text = " OR ".join(_req_text)
    if DEBUG:
        with open(os.path.join(APP_PATH, "debug.txt"), "w") as fw:
            fw.writelines([
                f"INPUT: '{_req_text}'\n",
                f"IsINPUT?: {bool(_req_text)}\n",
            ])

    # Get VIDEO from request; only the first one is used and it is handed to
    # main.py as-is. TODO: download remote videos to a local file first?
    _req_video = detools.get_request_video(model_request_dict)
    if isinstance(_req_video, list):
        _req_video = str(_req_video[0]) if _req_video else None

    # Both a prompt and a source video are required.
    if not _req_text or not _req_video:
        print("[ ERROR ] -> Desotacontrolvideo Request Failed: No Input found")
        sys.exit(1)

    # Run Model with the app's bundled interpreter
    _model_run = os.path.join(APP_PATH, "main.py")
    if USER_SYS == "win":
        _model_runner_py = os.path.join(APP_PATH, "env", "python.exe")
    elif USER_SYS == "lin":
        _model_runner_py = os.path.join(APP_PATH, "env", "bin", "python3")

    targs = _build_model_args(_req_text)
    print(targs)
    le_cmd = [
        _model_runner_py, _model_run,
        # Each list item is one argv entry (no shell), so the prompt needs no quoting;
        # wrapping it in '"' would pass literal quote characters to main.py.
        "--prompt", str(targs["prompt"]),
        "--resnum", str(_resnum),
        # NOTE(review): main.py receives only the file name, while the output check
        # below looks in TMP_PATH — presumably main.py writes into TMP_PATH; confirm.
        "--respath", str(out_filename),
        "--video_path", str(_req_video),
        "--condition", str(targs["condition"]),
        "--video_length", str(targs["video_length"]),
        "--width", str(targs['width']),
        "--height", str(targs['height']),
        "--frame_rate", str(targs["frame_rate"]),
        "--version", str(targs["version"]),  # You may need to adjust this
        "--seed", str(targs["seed"]),
    ]
    print(" ".join(le_cmd))

    # Block until the model exits (the previous poll() loop spun a CPU core).
    # TODO: implement model timeout
    _ret_code = subprocess.run(le_cmd).returncode
    if _ret_code != 0:
        print(f"[ WARNING ] -> Desotacontrolvideo model exited with code {_ret_code}")

    if not os.path.isfile(out_filepath):
        print("[ ERROR ] -> Desotacontrolvideo Request Failed: No Output found")
        sys.exit(2)

    if dev_mode:
        # Dev/test run: write a JSON report instead of uploading.
        if not report_path.endswith(".json"):
            report_path += ".json"
        with open(report_path, "w") as rw:
            json.dump(
                {
                    "Model Result Path": out_filepath,
                    "Processing Time": time.time() - _report_start_time,
                },
                rw,
                indent=2,
            )
        detools.user_chown(report_path)
        detools.user_chown(out_filepath)
        print(f"Path to report:\n\t{report_path}")
    else:
        print("[ INFO ] -> DesotaControlVideo Made it!")
        # DeSOTA API Response Post. The request must be sent while the file is
        # still open: requests reads the file object when it builds the body.
        with open(out_filepath, 'rb') as fr:
            files = [('upload[]', fr)]
            send_task = s.post(url=send_task_url, files=files)
        try:
            _upload_info = json.dumps(send_task.json(), indent=2)
        except ValueError:
            # Non-JSON body (e.g. an HTML error page): show it raw instead of crashing.
            _upload_info = send_task.text
        print(f"[ INFO ] -> DeSOTA API Upload:{_upload_info}")
        # Delete temporary file
        os.remove(out_filepath)
        if send_task.status_code != 200:
            print(f"[ ERROR ] -> Desotacontrolvideo Post Failed (Info):\nfiles: {files}\nResponse Code: {send_task.status_code}")
            sys.exit(3)

    print("TASK OK!")
    sys.exit(0)
if __name__ == "__main__":
    args = parser.parse_args()
    # argparse declares both flags optional, but main() needs them. The uncaught
    # exception exits with status 1, which main() documents as INPUT ERROR.
    if not args.model_req or not args.model_res_url:
        raise EnvironmentError(
            "Both -mr/--model_req and -mru/--model_res_url must be provided"
        )
    main(args)