Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Poorman's optimizations to save RAM usage #135

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/config/argument_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class ArgumentConfig(PrintableConfig):
source_image: Annotated[str, tyro.conf.arg(aliases=["-s"])] = make_abs_path('../../assets/examples/source/s6.jpg') # path to the source portrait
driving_info: Annotated[str, tyro.conf.arg(aliases=["-d"])] = make_abs_path('../../assets/examples/driving/d0.mp4') # path to driving video or template (.pkl format)
output_dir: Annotated[str, tyro.conf.arg(aliases=["-o"])] = 'animations/' # directory to save output video
lazy_loading: bool = False # load input images on demand one by one frame to minimize RAM usage

########## inference arguments ##########
flag_use_half_precision: bool = True # whether to use half precision (FP16). If black boxes appear, it might be due to GPU incompatibility; set to False.
Expand Down
4 changes: 2 additions & 2 deletions src/live_portrait_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def execute(self, args: ArgumentConfig):
log(f'The FPS of {args.driving_info} is: {output_fps}')

log(f"Load video file (mp4 mov avi etc...): {args.driving_info}")
driving_rgb_lst = load_driving_info(args.driving_info)
driving_rgb_lst = load_driving_info(args.driving_info, lazy=args.lazy_loading)

######## make motion template ########
log("Start making motion template...")
Expand Down Expand Up @@ -220,7 +220,7 @@ def execute(self, args: ArgumentConfig):
# driving frame | source image | generation, or source image | generation
frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, img_crop_256x256, I_p_lst)
wfp_concat = osp.join(args.output_dir, f'{basename(args.source_image)}--{basename(args.driving_info)}_concat.mp4')
images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)
images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps, frames_count=len(driving_rgb_crop_256x256_lst))

if flag_has_audio:
# final result with concact
Expand Down
30 changes: 28 additions & 2 deletions src/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,39 @@ def load_image_rgb(image_path: str):
def load_driving_info(driving_info):
driving_video_ori = []

from typing import Iterator

class LazyVideoFramesIterator:
def __init__(self, reader: 'imageio.Reader'):
self.data_iter = reader.iter_data()

def __iter__(self):
return self

def __next__(self):
return next(self.data_iter)

class LazyVideoFramesLoader:
def __init__(self, video_path: str) -> None:
self.video_path = video_path
self.reader = imageio.get_reader(video_path, "ffmpeg")

def __iter__(self) -> Iterator[LazyVideoFramesIterator]:
return LazyVideoFramesIterator(self.reader)

def __getitem__(self, key):
raise Exception("Indexing isn't implemented for lazy frames loading")

def load_images_from_directory(directory):
image_paths = sorted(glob(osp.join(directory, '*.png')) + glob(osp.join(directory, '*.jpg')))
return [load_image_rgb(im_path) for im_path in image_paths]

def load_images_from_video(file_path):
reader = imageio.get_reader(file_path, "ffmpeg")
return [image for _, image in enumerate(reader)]
if lazy:
return LazyVideoFramesLoader(file_path)
else:
reader = imageio.get_reader(file_path, "ffmpeg")
return [image for _, image in enumerate(reader)]

if osp.isdir(driving_info):
driving_video_ori = load_images_from_directory(driving_info)
Expand Down
18 changes: 11 additions & 7 deletions src/utils/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,18 @@ def images2video(images, wfp, **kwargs):
codec=codec, quality=quality, ffmpeg_params=ffmpeg_params, pixelformat=pixelformat, macro_block_size=macro_block_size
)

n = len(images)
n = len(images) if hasattr(images, '__len__') else kwargs.get('frames_count')
img_it = iter(images)
for i in track(range(n), description='Writing', transient=True):
try:
img = next(img_it)
except StopIteration:
break

if image_mode.lower() == 'bgr':
writer.append_data(images[i][..., ::-1])
writer.append_data(img[..., ::-1])
else:
writer.append_data(images[i])
writer.append_data(img)

writer.close()

Expand Down Expand Up @@ -82,10 +88,9 @@ def blend(img: np.ndarray, mask: np.ndarray, background_color=(255, 255, 255)):

def concat_frames(driving_image_lst, source_image, I_p_lst):
# TODO: add more concat style, e.g., left-down corner driving
out_lst = []
h, w, _ = I_p_lst[0].shape

for idx, _ in track(enumerate(I_p_lst), total=len(I_p_lst), description='Concatenating result...'):
for idx, _ in enumerate(I_p_lst):
I_p = I_p_lst[idx]
source_image_resized = cv2.resize(source_image, (w, h))

Expand All @@ -96,8 +101,7 @@ def concat_frames(driving_image_lst, source_image, I_p_lst):
driving_image_resized = cv2.resize(driving_image, (w, h))
out = np.hstack((driving_image_resized, source_image_resized, I_p))

out_lst.append(out)
return out_lst
yield out


class VideoWriter:
Expand Down