Added video output support for gradio (#113)

added video output support for gradio

---------

Co-authored-by: Sencer Yücel <sencer.yucel@turkai.com>
Co-authored-by: wa22 <wa22@mails.tsinghua.edu.cn>
This commit is contained in:
Sencer Yücel 2024-06-03 12:15:31 +03:00 committed by GitHub
parent 1539b5a678
commit 1cfe7a4e13
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

99
app.py
View File

@ -1,30 +1,60 @@
import PIL.Image as Image
import gradio as gr import gradio as gr
import cv2
import tempfile
from ultralytics import YOLOv10 from ultralytics import YOLOv10
def predict_image(img, model_id, image_size, conf_threshold):
def yolov10_inference(image, video, model_id, image_size, conf_threshold):
model = YOLOv10.from_pretrained(f'jameslahm/{model_id}') model = YOLOv10.from_pretrained(f'jameslahm/{model_id}')
results = model.predict( if image:
source=img, results = model.predict(source=image, imgsz=image_size, conf=conf_threshold)
conf=conf_threshold, annotated_image = results[0].plot()
show_labels=True, return annotated_image[:, :, ::-1], None
show_conf=True, else:
imgsz=image_size, video_path = tempfile.mktemp(suffix=".webm")
) with open(video_path, "wb") as f:
with open(video, "rb") as g:
f.write(g.read())
for r in results: cap = cv2.VideoCapture(video_path)
im_array = r.plot() fps = cap.get(cv2.CAP_PROP_FPS)
im = Image.fromarray(im_array[..., ::-1]) frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
output_video_path = tempfile.mktemp(suffix=".webm")
out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'vp80'), fps, (frame_width, frame_height))
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
results = model.predict(source=frame, imgsz=image_size, conf=conf_threshold)
annotated_frame = results[0].plot()
out.write(annotated_frame)
cap.release()
out.release()
return None, output_video_path
def yolov10_inference_for_examples(image, model_path, image_size, conf_threshold):
annotated_image, _ = yolov10_inference(image, None, model_path, image_size, conf_threshold)
return annotated_image
return im
def app(): def app():
with gr.Blocks(): with gr.Blocks():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
image = gr.Image(type="pil", label="Image") image = gr.Image(type="pil", label="Image", visible=True)
video = gr.Video(label="Video", visible=False)
input_type = gr.Radio(
choices=["Image", "Video"],
value="Image",
label="Input Type",
)
model_id = gr.Dropdown( model_id = gr.Dropdown(
label="Model", label="Model",
choices=[ choices=[
@ -54,17 +84,34 @@ def app():
yolov10_infer = gr.Button(value="Detect Objects") yolov10_infer = gr.Button(value="Detect Objects")
with gr.Column(): with gr.Column():
output_image = gr.Image(type="pil", label="Annotated Image") output_image = gr.Image(type="numpy", label="Annotated Image", visible=True)
output_video = gr.Video(label="Annotated Video", visible=False)
def update_visibility(input_type):
image = gr.update(visible=True) if input_type == "Image" else gr.update(visible=False)
video = gr.update(visible=False) if input_type == "Image" else gr.update(visible=True)
output_image = gr.update(visible=True) if input_type == "Image" else gr.update(visible=False)
output_video = gr.update(visible=False) if input_type == "Image" else gr.update(visible=True)
return image, video, output_image, output_video
input_type.change(
fn=update_visibility,
inputs=[input_type],
outputs=[image, video, output_image, output_video],
)
def run_inference(image, video, model_id, image_size, conf_threshold, input_type):
if input_type == "Image":
return yolov10_inference(image, None, model_id, image_size, conf_threshold)
else:
return yolov10_inference(None, video, model_id, image_size, conf_threshold)
yolov10_infer.click( yolov10_infer.click(
fn=predict_image, fn=run_inference,
inputs=[ inputs=[image, video, model_id, image_size, conf_threshold, input_type],
image, outputs=[output_image, output_video],
model_id,
image_size,
conf_threshold,
],
outputs=[output_image],
) )
gr.Examples( gr.Examples(
@ -82,7 +129,7 @@ def app():
0.25, 0.25,
], ],
], ],
fn=predict_image, fn=yolov10_inference_for_examples,
inputs=[ inputs=[
image, image,
model_id, model_id,