From 62742c21d2f90639b81e26c7d46e76284d98c564 Mon Sep 17 00:00:00 2001
From: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Date: Thu, 25 Jan 2024 22:05:46 +0530
Subject: [PATCH] Adds toggle displaying labels in GUI and verbose log on start
 (#7804)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
---
 ultralytics/cfg/__init__.py           |  1 +
 ultralytics/data/explorer/gui/dash.py | 26 +++++++++++++++++++++++---
 2 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py
index b9069477..6038baf0 100644
--- a/ultralytics/cfg/__init__.py
+++ b/ultralytics/cfg/__init__.py
@@ -396,6 +396,7 @@ def handle_yolo_settings(args: List[str]) -> None:
 def handle_explorer():
     """Open the Ultralytics Explorer GUI."""
     checks.check_requirements("streamlit")
+    LOGGER.info(f"💡 Loading Explorer dashboard...")
     subprocess.run(["streamlit", "run", ROOT / "data/explorer/gui/dash.py", "--server.maxMessageSize", "2048"])
 
 
diff --git a/ultralytics/data/explorer/gui/dash.py b/ultralytics/data/explorer/gui/dash.py
index 1ef7fe42..36ab3c0d 100644
--- a/ultralytics/data/explorer/gui/dash.py
+++ b/ultralytics/data/explorer/gui/dash.py
@@ -9,7 +9,7 @@ from ultralytics import Explorer
 from ultralytics.utils import ROOT, SETTINGS
 from ultralytics.utils.checks import check_requirements
 
-check_requirements(("streamlit>=1.29.0", "streamlit-select>=0.2"))
+check_requirements(("streamlit>=1.29.0", "streamlit-select>=0.3"))
 
 import streamlit as st
 from streamlit_select import image_select
@@ -94,6 +94,7 @@ def find_similar_imgs(imgs):
     similar = exp.get_similar(img=imgs, limit=st.session_state.get("limit"), return_type="arrow")
     paths = similar.to_pydict()["im_file"]
     st.session_state["imgs"] = paths
+    st.session_state["res"] = similar
 
 
 def similarity_form(selected_imgs):
@@ -137,6 +138,7 @@ def run_sql_query():
         exp = st.session_state["explorer"]
         res = exp.sql_query(query, return_type="arrow")
         st.session_state["imgs"] = res.to_pydict()["im_file"]
+        st.session_state["res"] = res
 
 
 def run_ai_query():
@@ -155,6 +157,7 @@ def run_ai_query():
             st.session_state["error"] = "No results found using AI generated query. Try another query or rerun it."
             return
         st.session_state["imgs"] = res["im_file"].to_list()
+        st.session_state["res"] = res
 
 
 def reset_explorer():
@@ -195,7 +198,11 @@ def layout():
     if st.session_state.get("error"):
         st.error(st.session_state["error"])
     else:
-        imgs = st.session_state.get("imgs") or exp.table.to_lance().to_table(columns=["im_file"]).to_pydict()["im_file"]
+        if st.session_state.get("imgs"):
+            imgs = st.session_state.get("imgs")
+        else:
+            imgs = exp.table.to_lance().to_table(columns=["im_file"]).to_pydict()["im_file"]
+            st.session_state["res"] = exp.table.to_arrow()
     total_imgs, selected_imgs = len(imgs), []
     with col1:
         subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5)
@@ -230,17 +237,30 @@ def layout():
         query_form()
         ai_query_form()
         if total_imgs:
+            labels, boxes, masks, kpts, classes = None, None, None, None, None
+            task = exp.model.task
+            if st.session_state.get("display_labels"):
+                labels = st.session_state.get("res").to_pydict()["labels"][start_idx : start_idx + num]
+                boxes = st.session_state.get("res").to_pydict()["bboxes"][start_idx : start_idx + num]
+                masks = st.session_state.get("res").to_pydict()["masks"][start_idx : start_idx + num]
+                kpts = st.session_state.get("res").to_pydict()["keypoints"][start_idx : start_idx + num]
+                classes = st.session_state.get("res").to_pydict()["cls"][start_idx : start_idx + num]
             imgs_displayed = imgs[start_idx : start_idx + num]
             selected_imgs = image_select(
                 f"Total samples: {total_imgs}",
                 images=imgs_displayed,
                 use_container_width=False,
                 # indices=[i for i in range(num)] if select_all else None,
+                labels=labels,
+                classes=classes,
+                bboxes=boxes,
+                masks=masks if task == "segment" else None,
+                kpts=kpts if task == "pose" else None,
             )
 
     with col2:
         similarity_form(selected_imgs)
-        # display_labels = st.checkbox("Labels", value=False, key="display_labels")
+        display_labels = st.checkbox("Labels", value=False, key="display_labels")
         utralytics_explorer_docs_callback()