From aa139946dff5506b45d3c61e55575258a5b3e7a0 Mon Sep 17 00:00:00 2001
From: Ivan Murabito <36967518+CuriousDolphin@users.noreply.github.com>
Date: Fri, 3 Nov 2023 10:18:49 +0100
Subject: [PATCH] add tabbed interface

---
 yologp/frame_extractor_gradio_app.py | 35 +++++++++++++++++++++++++++++++++------
 yologp/gradio_app.py                 | 13 +++++++++++++
 yologp/inference_gradio_app.py       | 25 ++++++++++++++++++++++++---
 3 files changed, 64 insertions(+), 9 deletions(-)
 create mode 100644 yologp/gradio_app.py

diff --git a/yologp/frame_extractor_gradio_app.py b/yologp/frame_extractor_gradio_app.py
index 41797a1..ffc5158 100644
--- a/yologp/frame_extractor_gradio_app.py
+++ b/yologp/frame_extractor_gradio_app.py
@@ -1,7 +1,9 @@
+from typing import Optional
 from pytube import YouTube
 import gradio as gr
 from pathlib import Path
 import os
+import cv2
 from supervision import (
     ImageSink,
     get_video_frames_generator,
@@ -10,11 +12,24 @@
 from tqdm import tqdm
 from helpers import zoom_center
 import shutil
+import numpy as np
 
 data_path = Path(__file__).parent.parent / "data"
 print("DATA PATH: ", data_path)
 
 
+class MyImageSink(ImageSink):
+    def save_image(
+        self, image: np.ndarray, image_name: Optional[str] = None, quality: int = 70
+    ):
+        if image_name is None:
+            image_name = self.image_name_pattern.format(self.image_count)
+
+        image_path = os.path.join(self.target_dir_path, image_name)
+        cv2.imwrite(image_path, image, [cv2.IMWRITE_JPEG_QUALITY, quality])
+        self.image_count += 1
+
+
 def download_youtube_url(url, out_dir) -> str:
     yt = YouTube(url=url)
     files = yt.streams.filter(file_extension="mp4", only_video=True)
@@ -31,6 +46,7 @@ def extract_frames(
     start,
     end,
     resize_w,
+    quality,
     zoom,
     progress=gr.Progress(track_tqdm=True),
 ):
@@ -44,19 +60,25 @@ def extract_frames(
     video_name = str(v_path.stem).replace(" ", "")
     target_dir = Path(f"{data_path}/{video_name}_frames")
     cont = 0
-    with ImageSink(
+    with MyImageSink(
         target_dir_path=target_dir,
         image_name_pattern="image_{:05d}.jpg",
         overwrite=True,
     ) as sink:
         for image in tqdm(
             get_video_frames_generator(
-                source_path=str(v_path), stride=stride, start=start
+                source_path=str(v_path),
+                stride=stride,
+                start=start,
+                end=end if end != -1 else None,
             )
         ):
             if zoom > 1:
                 image = zoom_center(img=image.copy(), zoom_factor=zoom)
-            sink.save_image(image=image.copy())
+            sink.save_image(
+                image=image.copy(),
+                quality=quality,
+            )
             cont += 1
     progress(0.8, "Zipping..")
     print("Target_dir", target_dir)
@@ -81,17 +103,18 @@ def extract_frames(
     gr.Number(label="Start Frame", value=0),
     gr.Number(label="End Frame", value=-1),
     gr.Number(label="Resize Width (px)", value=-1),
+    gr.Slider(label="Quality", minimum=0, maximum=100, value=70),
     gr.Slider(label="Image Zoom", minimum=1.0, maximum=2.99, value=1.4),
 ]
 outputs = [gr.Gallery(label="preview"), gr.File()]
 
-interface = gr.Interface(
+frame_ext_interface = gr.Interface(
     fn=extract_frames,
     inputs=inputs,
     outputs=outputs,
     examples=[["https://www.youtube.com/watch?v=XDhjS_fzhsQ"]],
-    allow_flagging=False,
+    allow_flagging="never",
 )
 
 if __name__ == "__main__":
-    interface.queue(max_size=10).launch(server_name="0.0.0.0")
+    frame_ext_interface.queue(max_size=10).launch(server_name="0.0.0.0")
diff --git a/yologp/gradio_app.py b/yologp/gradio_app.py
new file mode 100644
index 0000000..a3a3ca8
--- /dev/null
+++ b/yologp/gradio_app.py
@@ -0,0 +1,13 @@
+import gradio as gr
+
+from frame_extractor_gradio_app import frame_ext_interface
+from inference_gradio_app import inference_interface
+
+
+tabbed_interface = gr.TabbedInterface(
+    interface_list=[inference_interface, frame_ext_interface],
+    tab_names=["Inference", "Extract Frame"],
+)
+
+if __name__ == "__main__":
+    tabbed_interface.queue(max_size=10).launch(server_name="0.0.0.0")
diff --git a/yologp/inference_gradio_app.py b/yologp/inference_gradio_app.py
index 1e2e157..6d8ee19 100644
--- a/yologp/inference_gradio_app.py
+++ b/yologp/inference_gradio_app.py
@@ -40,7 +40,7 @@ def inference(image, conf: float, iou: float, progress=gr.Progress()):
     return frame
 
 
-with gr.Blocks() as inference_app:
+""" with gr.Blocks() as inference_app:
     gr.Markdown("# 🏍️ YoloGP: Motogp tracker")
     with gr.Row():
         with gr.Column():
@@ -60,7 +60,26 @@ def inference(image, conf: float, iou: float, progress=gr.Progress()):
         with gr.Column():
             output_im = gr.Image()
 
-    button.click(fn=inference, inputs=[image, conf, iou], outputs=output_im)
+    button.click(fn=inference, inputs=[image, conf, iou], outputs=output_im) """
+
+
+inference_interface = gr.Interface(
+    description="# 🏍️ YoloGP: Motogp tracker (YoloV8 nano, detection & segmentation)",
+    fn=inference,
+    inputs=[
+        gr.Image(),
+        gr.Slider(label="Confidence", minimum=0, maximum=0.99, value=0.3),
+        gr.Slider(label="IoU", minimum=0, maximum=0.99, value=0.45),
+    ],
+    outputs=[gr.Image()],
+    examples=[
+        ["./assets/Rossi_Lorenzo_Catalunya2009.png"],
+        ["./assets/sample1.png"],
+    ],
+    allow_flagging="never",
+)
+
 
 if __name__ == "__main__":
-    inference_app.queue(max_size=10).launch(server_name="0.0.0.0")
+    inference_interface.queue().launch(server_name="0.0.0.0")
+    # inference_app.queue(max_size=10).launch(server_name="0.0.0.0")
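
Note on the new gradio_app.py entry point: it composes the two existing apps with gr.TabbedInterface, one tab per pre-built gr.Interface. Below is a minimal, self-contained sketch of that pattern; the echo/shout functions, interface names, and tab labels are illustrative placeholders standing in for inference_interface and frame_ext_interface, not code from the repo.

import gradio as gr


def echo(text: str) -> str:
    # Placeholder for the real inference app.
    return text


def shout(text: str) -> str:
    # Placeholder for the real frame-extractor app.
    return text.upper()


echo_ui = gr.Interface(fn=echo, inputs=gr.Textbox(), outputs=gr.Textbox())
shout_ui = gr.Interface(fn=shout, inputs=gr.Textbox(), outputs=gr.Textbox())

# Same composition as gradio_app.py: each Interface becomes one tab.
demo = gr.TabbedInterface(
    interface_list=[echo_ui, shout_ui],
    tab_names=["Echo", "Shout"],
)

if __name__ == "__main__":
    # queue() bounds concurrent jobs; server_name="0.0.0.0" listens on all interfaces.
    demo.queue(max_size=10).launch(server_name="0.0.0.0")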
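The MyImageSink override in frame_extractor_gradio_app.py forwards the new Quality slider to OpenCV's JPEG encoder through cv2.imwrite. A small stand-alone sketch of that encoder flag, assuming opencv-python and numpy are installed; the dummy frame and output file names are placeholders.

import cv2
import numpy as np

# Dummy 640x480 BGR frame; in the app this is a frame yielded by supervision.
frame = np.zeros((480, 640, 3), dtype=np.uint8)

# IMWRITE_JPEG_QUALITY accepts 0-100; lower values give smaller, lossier files.
cv2.imwrite("frame_q70.jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 70])
cv2.imwrite("frame_q95.jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 95])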