import os
import shutil
import subprocess
import tempfile
from typing import Mapping

import cv2
import numpy as np
import torch
import torchvision.transforms.functional as transforms_F
from einops import rearrange
from PIL import Image

from video_to_video.utils.logger import get_logger

logger = get_logger()

def tensor2vid(video, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Denormalize a (b, c, f, h, w) video tensor into (f, h, w, c) frames in [0, 255]."""
    mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
    std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
    video = video.mul_(std).add_(mean)  # undo normalization (in-place)
    video.clamp_(0, 1)
    video = video * 255.0
    # Channels-last layout; keep only the first batch element.
    images = rearrange(video, "b c f h w -> b f h w c")[0]
    return images
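
# Illustrative usage (shapes are hypothetical, chosen only for the example):
#   vid = torch.rand(1, 3, 8, 256, 256) * 2 - 1   # normalized input in [-1, 1]
#   frames = tensor2vid(vid)                      # -> torch.Size([8, 256, 256, 3])
#   first = frames[0].byte().cpu().numpy()        # uint8 frame, ready for saving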

def preprocess(input_frames):
    """Convert a list of BGR uint8 frames into a normalized (n, c, h, w) tensor."""
    out_frame_list = []
    for frame in input_frames:
        frame = frame[:, :, ::-1]  # BGR (OpenCV) -> RGB
        frame = Image.fromarray(frame.astype("uint8")).convert("RGB")
        frame = transforms_F.to_tensor(frame)  # (c, h, w), values in [0, 1]
        out_frame_list.append(frame)
    out_frames = torch.stack(out_frame_list, dim=0)
    out_frames.clamp_(0, 1)
    # Normalize to [-1, 1] with per-channel mean/std of 0.5.
    mean = out_frames.new_tensor([0.5, 0.5, 0.5]).view(1, -1, 1, 1)
    std = out_frames.new_tensor([0.5, 0.5, 0.5]).view(1, -1, 1, 1)
    out_frames.sub_(mean).div_(std)
    return out_frames
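
# Sketch of the expected round trip (a hypothetical 2-frame input): preprocess
# produces the [-1, 1] layout that tensor2vid above undoes.
#   bgr_frames = [np.zeros((64, 64, 3), dtype=np.uint8) for _ in range(2)]
#   batch = preprocess(bgr_frames)   # -> torch.Size([2, 3, 64, 64])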

def adjust_resolution(h, w, up_scale):
    """Choose an even target resolution, clamping the upscaled area.

    The output area is kept between 720*1280*1.5 and 1152*2048 pixels: the
    requested scale is raised or lowered as needed, and both sides are rounded
    down to even numbers for codec compatibility.
    """
    if h * w * up_scale * up_scale < 720 * 1280 * 1.5:
        up_s = np.sqrt(720 * 1280 * 1.5 / (h * w))
    elif h * w * up_scale * up_scale > 1152 * 2048:
        up_s = np.sqrt(1152 * 2048 / (h * w))
    else:
        up_s = up_scale
    target_h = int(up_s * h // 2 * 2)
    target_w = int(up_s * w // 2 * 2)
    return (target_h, target_w)
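
# Worked example of the lower clamp (plain arithmetic, no repo-specific data):
#   adjust_resolution(360, 640, up_scale=2)
# 360 * 640 * 2 * 2 = 921600 < 720 * 1280 * 1.5 = 1382400, so the scale is
# raised to sqrt(1382400 / (360 * 640)) ~= 2.449, giving (880, 1566).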

def make_mask_cond(in_f_num, interp_f_num):
    """Build a frame-index mask: real input frames keep their index, and the
    interp_f_num interpolated slots between each pair are marked with -1."""
    mask_cond = []
    interp_cond = [-1] * interp_f_num
    for i in range(in_f_num):
        mask_cond.append(i)
        if i != in_f_num - 1:  # no interpolated slots after the last frame
            mask_cond += interp_cond
    return mask_cond
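
# Small example of the mask layout:
#   make_mask_cond(3, 2) == [0, -1, -1, 1, -1, -1, 2]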

def load_prompt_list(file_path):
    """Read one entry per line from file_path, skipping blank lines."""
    files = []
    with open(file_path, "r") as fin:
        for line in fin:
            path = line.strip()
            if path:
                files.append(path)
    return files
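
# Expected file layout (these entries are hypothetical):
#   a_beach_at_sunset.mp4
#   a_city_street_at_night.mp4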

def load_video(vid_path):
    """Decode a video with OpenCV; returns (frames, fps) with BGR uint8 frames."""
    capture = cv2.VideoCapture(vid_path)
    _fps = capture.get(cv2.CAP_PROP_FPS)
    _total_frame_num = capture.get(cv2.CAP_PROP_FRAME_COUNT)
    pointer = 0
    frame_list = []
    stride = 1  # keep every frame; raise to subsample
    while len(frame_list) < _total_frame_num:
        ret, frame = capture.read()
        pointer += 1
        if (not ret) or (frame is None):
            break
        if pointer >= _total_frame_num + 1:
            break
        if pointer % stride == 0:
            frame_list.append(frame)
    capture.release()
    return frame_list, _fps
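
# Typical call (the path is hypothetical); frames come back in BGR order, so
# they can be fed straight into preprocess above:
#   frames, fps = load_video("input.mp4")
#   batch = preprocess(frames)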

def save_video(video, save_dir, file_name, fps=16.0):
    """Encode a sequence of RGB uint8 frames to an H.264 mp4 via ffmpeg."""
    output_path = os.path.join(save_dir, file_name)
    images = [img.numpy().astype("uint8") for img in video]
    temp_dir = tempfile.mkdtemp()
    for fid, frame in enumerate(images):
        tpth = os.path.join(temp_dir, "%06d.png" % (fid + 1))
        cv2.imwrite(tpth, frame[:, :, ::-1])  # RGB -> BGR for cv2
    tmp_path = os.path.join(save_dir, "tmp.mp4")
    cmd = (
        f"ffmpeg -y -f image2 -framerate {fps} -i '{temp_dir}/%06d.png' "
        f"-vcodec libx264 -crf 17 -pix_fmt yuv420p '{tmp_path}'"
    )
    status, output = subprocess.getstatusoutput(cmd)
    shutil.rmtree(temp_dir)
    if status != 0:
        logger.error(f"Save Video Error with {output}")
        return
    os.rename(tmp_path, output_path)
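
# End-to-end sketch combining the helpers above (all names and the model call
# are hypothetical):
#   frames, fps = load_video("input.mp4")
#   generated = ...                          # model output, (b, c, f, h, w)
#   out = tensor2vid(generated)              # (f, h, w, c) in [0, 255]
#   save_video(out.cpu(), "results", "out.mp4", fps=fps)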

def collate_fn(data, device):
    """Prepare the input just before the forward function.

    This method will move the tensors to the right device. Usually it does
    not need to be overridden.

    Args:
        data: The data out of the dataloader.
        device: The device to move data to.

    Returns: The processed data.
    """
    from torch.utils.data.dataloader import default_collate

    if isinstance(data, Mapping):  # covers dict and other mapping types
        # "img_metas" holds bookkeeping info and is deliberately left untouched.
        return type(data)({k: collate_fn(v, device) if k != "img_metas" else v for k, v in data.items()})
    elif isinstance(data, (tuple, list)):
        if len(data) == 0:
            return torch.Tensor([])
        if isinstance(data[0], (int, float)):
            return default_collate(data).to(device)
        else:
            return type(data)(collate_fn(v, device) for v in data)
    elif isinstance(data, np.ndarray):
        if data.dtype.type is np.str_:
            return data  # string arrays cannot live on a device
        else:
            return collate_fn(torch.from_numpy(data), device)
    elif isinstance(data, torch.Tensor):
        return data.to(device)
    elif isinstance(data, (bytes, str, int, float, bool, type(None))):
        return data
    else:
        raise ValueError(f"Unsupported data type {type(data)}")
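
# Minimal illustration of the recursion (keys and shapes are made up): tensors
# and numeric lists are moved to the device, while "img_metas" and plain
# strings pass through unchanged.
#   batch = {"video_data": torch.randn(1, 3, 8, 64, 64), "img_metas": {"name": "clip0"}}
#   batch = collate_fn(batch, torch.device("cpu"))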