Skip to content

Commit

Permalink
add tabbed interface
Browse files Browse the repository at this point in the history
  • Loading branch information
CuriousDolphin committed Nov 3, 2023
1 parent a1c9872 commit aa13994
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 9 deletions.
35 changes: 29 additions & 6 deletions yologp/frame_extractor_gradio_app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Optional
from pytube import YouTube
import gradio as gr
from pathlib import Path
import os
import cv2
from supervision import (
ImageSink,
get_video_frames_generator,
Expand All @@ -10,11 +12,24 @@
from tqdm import tqdm
from helpers import zoom_center
import shutil
import numpy as np

data_path = Path(__file__).parent.parent / "data"
print("DATA PATH: ", data_path)


class MyImageSink(ImageSink):
def save_image(
self, image: np.ndarray, image_name: Optional[str] = None, quality: int = 70
):
if image_name is None:
image_name = self.image_name_pattern.format(self.image_count)

image_path = os.path.join(self.target_dir_path, image_name)
cv2.imwrite(image_path, image, [cv2.IMWRITE_JPEG_QUALITY, quality])
self.image_count += 1


def download_youtube_url(url, out_dir) -> str:
yt = YouTube(url=url)
files = yt.streams.filter(file_extension="mp4", only_video=True)
Expand All @@ -31,6 +46,7 @@ def extract_frames(
start,
end,
resize_w,
quality,
zoom,
progress=gr.Progress(track_tqdm=True),
):
Expand All @@ -44,19 +60,25 @@ def extract_frames(
video_name = str(v_path.stem).replace(" ", "")
target_dir = Path(f"{data_path}/{video_name}_frames")
cont = 0
with ImageSink(
with MyImageSink(
target_dir_path=target_dir,
image_name_pattern="image_{:05d}.jpg",
overwrite=True,
) as sink:
for image in tqdm(
get_video_frames_generator(
source_path=str(v_path), stride=stride, start=start
source_path=str(v_path),
stride=stride,
start=start,
end=end if end != -1 else None,
)
):
if zoom > 1:
image = zoom_center(img=image.copy(), zoom_factor=zoom)
sink.save_image(image=image.copy())
sink.save_image(
image=image.copy(),
quality=quality,
)
cont += 1
progress(0.8, "Zipping..")
print("Target_dir", target_dir)
Expand All @@ -81,17 +103,18 @@ def extract_frames(
gr.Number(label="Start Frame", value=0),
gr.Number(label="End Frame", value=-1),
gr.Number(label="Resize Width (px)", value=-1),
gr.Slider(label="Quality", minimum=0, maximum=100, value=70),
gr.Slider(label="Image Zoom", minimum=1.0, maximum=2.99, value=1.4),
]
outputs = [gr.Gallery(label="preview"), gr.File()]
interface = gr.Interface(
frame_ext_interface = gr.Interface(
fn=extract_frames,
inputs=inputs,
outputs=outputs,
examples=[["https://www.youtube.com/watch?v=XDhjS_fzhsQ"]],
allow_flagging=False,
allow_flagging="never",
)


if __name__ == "__main__":
interface.queue(max_size=10).launch(server_name="0.0.0.0")
frame_ext_interface.queue(max_size=10).launch(server_name="0.0.0.0")
13 changes: 13 additions & 0 deletions yologp/gradio_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import gradio as gr

from frame_extractor_gradio_app import frame_ext_interface
from inference_gradio_app import inference_interface


tabbed_interface = gr.TabbedInterface(
interface_list=[inference_interface, frame_ext_interface],
tab_names=["Inference", "Extract Frame"],
)

if __name__ == "__main__":
tabbed_interface.queue(max_size=10).launch(server_name="0.0.0.0")
25 changes: 22 additions & 3 deletions yologp/inference_gradio_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def inference(image, conf: float, iou: float, progress=gr.Progress()):
return frame


with gr.Blocks() as inference_app:
""" with gr.Blocks() as inference_app:
gr.Markdown("# 🏍️ YoloGP: Motogp tracker")
with gr.Row():
with gr.Column():
Expand All @@ -60,7 +60,26 @@ def inference(image, conf: float, iou: float, progress=gr.Progress()):
with gr.Column():
output_im = gr.Image()
button.click(fn=inference, inputs=[image, conf, iou], outputs=output_im)
button.click(fn=inference, inputs=[image, conf, iou], outputs=output_im) """


inference_interface = gr.Interface(
description="# 🏍️ YoloGP: Motogp tracker (YoloV8 nano, detection & segmentation)",
fn=inference,
inputs=[
gr.Image(),
gr.Slider(label="Confidence", minimum=0, maximum=0.99, value=0.3),
gr.Slider(label="IoU", minimum=0, maximum=0.99, value=0.45),
],
outputs=[gr.Image()],
examples=[
["./assets/Rossi_Lorenzo_Catalunya2009.png"],
["./assets/sample1.png"],
],
allow_flagging="never",
)


if __name__ == "__main__":
inference_app.queue(max_size=10).launch(server_name="0.0.0.0")
inference_interface.queue().launch(server_name="0.0.0.0")
# inference_app.queue(max_size=10).launch(server_name="0.0.0.0")

0 comments on commit aa13994

Please sign in to comment.