diff --git a/app.py b/app.py index 5598af6d..6fddcbaa 100644 --- a/app.py +++ b/app.py @@ -1,30 +1,60 @@ -import PIL.Image as Image import gradio as gr - +import cv2 +import tempfile from ultralytics import YOLOv10 -def predict_image(img, model_id, image_size, conf_threshold): + +def yolov10_inference(image, video, model_id, image_size, conf_threshold): model = YOLOv10.from_pretrained(f'jameslahm/{model_id}') - results = model.predict( - source=img, - conf=conf_threshold, - show_labels=True, - show_conf=True, - imgsz=image_size, - ) + if image: + results = model.predict(source=image, imgsz=image_size, conf=conf_threshold) + annotated_image = results[0].plot() + return annotated_image[:, :, ::-1], None + else: + video_path = tempfile.mktemp(suffix=".webm") + with open(video_path, "wb") as f: + with open(video, "rb") as g: + f.write(g.read()) - for r in results: - im_array = r.plot() - im = Image.fromarray(im_array[..., ::-1]) + cap = cv2.VideoCapture(video_path) + fps = cap.get(cv2.CAP_PROP_FPS) + frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + + output_video_path = tempfile.mktemp(suffix=".webm") + out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'vp80'), fps, (frame_width, frame_height)) + + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + + results = model.predict(source=frame, imgsz=image_size, conf=conf_threshold) + annotated_frame = results[0].plot() + out.write(annotated_frame) + + cap.release() + out.release() + + return None, output_video_path + + +def yolov10_inference_for_examples(image, model_path, image_size, conf_threshold): + annotated_image, _ = yolov10_inference(image, None, model_path, image_size, conf_threshold) + return annotated_image - return im def app(): with gr.Blocks(): with gr.Row(): with gr.Column(): - image = gr.Image(type="pil", label="Image") - + image = gr.Image(type="pil", label="Image", visible=True) + video = gr.Video(label="Video", visible=False) + input_type = gr.Radio( + choices=["Image", "Video"], + value="Image", + label="Input Type", + ) model_id = gr.Dropdown( label="Model", choices=[ @@ -54,17 +84,34 @@ def app(): yolov10_infer = gr.Button(value="Detect Objects") with gr.Column(): - output_image = gr.Image(type="pil", label="Annotated Image") + output_image = gr.Image(type="numpy", label="Annotated Image", visible=True) + output_video = gr.Video(label="Annotated Video", visible=False) + + def update_visibility(input_type): + image = gr.update(visible=True) if input_type == "Image" else gr.update(visible=False) + video = gr.update(visible=False) if input_type == "Image" else gr.update(visible=True) + output_image = gr.update(visible=True) if input_type == "Image" else gr.update(visible=False) + output_video = gr.update(visible=False) if input_type == "Image" else gr.update(visible=True) + + return image, video, output_image, output_video + + input_type.change( + fn=update_visibility, + inputs=[input_type], + outputs=[image, video, output_image, output_video], + ) + + def run_inference(image, video, model_id, image_size, conf_threshold, input_type): + if input_type == "Image": + return yolov10_inference(image, None, model_id, image_size, conf_threshold) + else: + return yolov10_inference(None, video, model_id, image_size, conf_threshold) + yolov10_infer.click( - fn=predict_image, - inputs=[ - image, - model_id, - image_size, - conf_threshold, - ], - outputs=[output_image], + fn=run_inference, + inputs=[image, video, model_id, image_size, conf_threshold, input_type], + outputs=[output_image, output_video], ) gr.Examples( @@ -82,7 +129,7 @@ def app(): 0.25, ], ], - fn=predict_image, + fn=yolov10_inference_for_examples, inputs=[ image, model_id, @@ -111,4 +158,4 @@ with gradio_app: with gr.Column(): app() if __name__ == '__main__': - gradio_app.launch() \ No newline at end of file + gradio_app.launch()