Ayush Chaurasia a92adf8231
Docs updates: Add Explorer to tab, YOLOv5 in Guides and Usage in Quickstart (#7438)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Haixuan Xavier Tao <tao.xavier@outlook.com>
2024-01-09 23:50:26 +01:00

227 lines
8.8 KiB
Python

# Ultralytics YOLO 🚀, AGPL-3.0 license
import time
from threading import Thread
import pandas as pd
from ultralytics import Explorer
from ultralytics.utils import ROOT, SETTINGS
from ultralytics.utils.checks import check_requirements
check_requirements(('streamlit>=1.29.0', 'streamlit-select>=0.2'))
import streamlit as st
from streamlit_select import image_select
def _get_explorer():
    """
    Build an Explorer for the selected dataset/model and store it in the Streamlit session state.

    Reads 'dataset', 'model' and 'force_recreate_embeddings' from `st.session_state`, runs
    `Explorer.create_embeddings_table` on a worker thread and mirrors `exp.progress` onto a Streamlit progress bar
    while the worker runs. The Explorer is then saved under `st.session_state['explorer']`; nothing is returned.
    """
    exp = Explorer(data=st.session_state.get('dataset'), model=st.session_state.get('model'))
    thread = Thread(target=exp.create_embeddings_table,
                    kwargs={'force': st.session_state.get('force_recreate_embeddings')})
    thread.start()
    progress_bar = st.progress(0, text='Creating embeddings table...')
    # Stop polling as soon as the worker has exited as well: if it ended before progress reached 1 (e.g. it raised),
    # waiting on `exp.progress` alone would spin forever and freeze the app.
    while thread.is_alive() and exp.progress < 1:
        time.sleep(0.1)
        progress = exp.progress  # read once so the bar value and its label always agree
        progress_bar.progress(progress, text=f'Progress: {progress * 100}%')
    thread.join()
    st.session_state['explorer'] = exp
    progress_bar.empty()
def init_explorer_form():
    """
    Render the start-up form used to pick a dataset and a model.

    Dataset choices are the `*.yaml` file names found under `ROOT / 'cfg' / 'datasets'`, with 'coco128.yaml'
    preselected. Submitting the form invokes `_get_explorer` as the button callback.
    """
    dataset_dir = ROOT / 'cfg' / 'datasets'
    dataset_names = [cfg.name for cfg in dataset_dir.glob('*.yaml')]
    model_names = [
        # detection
        'yolov8n.pt', 'yolov8s.pt', 'yolov8m.pt', 'yolov8l.pt', 'yolov8x.pt',
        # segmentation
        'yolov8n-seg.pt', 'yolov8s-seg.pt', 'yolov8m-seg.pt', 'yolov8l-seg.pt', 'yolov8x-seg.pt',
        # pose
        'yolov8n-pose.pt', 'yolov8s-pose.pt', 'yolov8m-pose.pt', 'yolov8l-pose.pt', 'yolov8x-pose.pt', ]
    with st.form(key='explorer_init_form'):
        dataset_col, model_col = st.columns(2)
        with dataset_col:
            default_idx = dataset_names.index('coco128.yaml')
            st.selectbox('Select dataset', dataset_names, key='dataset', index=default_idx)
        with model_col:
            st.selectbox('Select model', model_names, key='model')
        st.checkbox('Force recreate embeddings', key='force_recreate_embeddings')
        st.form_submit_button('Explore', on_click=_get_explorer)
def query_form():
    """Render the SQL query form: a text box stored under the 'query' key and a button calling `run_sql_query`."""
    default_query = "WHERE labels LIKE '%person%' AND labels LIKE '%dog%'"
    with st.form('query_form'):
        input_col, button_col = st.columns([0.8, 0.2])
        with input_col:
            st.text_input('Query', default_query, label_visibility='collapsed', key='query')
        with button_col:
            st.form_submit_button('Query', on_click=run_sql_query)
def ai_query_form():
    """Render the natural-language query form: a text box stored under 'ai_query' and an 'Ask AI' button."""
    default_prompt = 'Show images with 1 person and 1 dog'
    with st.form('ai_query_form'):
        input_col, button_col = st.columns([0.8, 0.2])
        with input_col:
            st.text_input('Query', default_prompt, label_visibility='collapsed', key='ai_query')
        with button_col:
            # The callback reads the prompt back from st.session_state['ai_query'].
            st.form_submit_button('Ask AI', on_click=run_ai_query)
def find_similar_imgs(imgs):
    """
    Search for images similar to `imgs` and make the matches the currently displayed image set.

    Uses the Explorer stored in the session state, caps the search with the 'limit' session value, and writes the
    'im_file' column of the result to `st.session_state['imgs']`.
    """
    max_results = st.session_state.get('limit')
    matches = st.session_state['explorer'].get_similar(img=imgs, limit=max_results, return_type='arrow')
    st.session_state['imgs'] = matches.to_pydict()['im_file']
def similarity_form(selected_imgs):
    """
    Render the similarity-search form for the currently selected images.

    Shows a numeric 'limit' input and a 'Search' button that calls `find_similar_imgs(selected_imgs)`. The button
    is disabled, and an error hint is shown, while `selected_imgs` is empty.
    """
    num_selected = len(selected_imgs)
    nothing_selected = num_selected == 0
    st.write('Similarity Search')
    with st.form('similarity_form'):
        limit_col, action_col = st.columns([1, 1])
        with limit_col:
            st.number_input('limit', min_value=None, max_value=None, value=25, label_visibility='collapsed',
                            key='limit')
        with action_col:
            st.write('Selected: ', num_selected)
            st.form_submit_button('Search',
                                  disabled=nothing_selected,
                                  on_click=find_similar_imgs,
                                  args=(selected_imgs, ))
        if nothing_selected:
            st.error('Select at least one image to search.')
# def persist_reset_form():
# with st.form("persist_reset"):
# col1, col2 = st.columns([1, 1])
# with col1:
# st.form_submit_button("Reset", on_click=reset)
#
# with col2:
# st.form_submit_button("Persist", on_click=update_state, args=("PERSISTING", True))
def run_sql_query():
    """
    Run the SQL query typed into the query form and store the matching image paths in the session state.

    Reads the text under `st.session_state['query']`; blank or missing input is ignored. On success the 'im_file'
    column of the result replaces `st.session_state['imgs']`. Nothing is returned. If the Explorer rejects the
    query with a ValueError, the message is stored in `st.session_state['error']` (the same channel
    `run_ai_query` uses) instead of letting the callback crash.
    """
    st.session_state['error'] = None
    # Strip once and pass the cleaned text on, so stray leading/trailing whitespace cannot affect the query.
    query = (st.session_state.get('query') or '').strip()
    if not query:
        return
    exp = st.session_state['explorer']
    try:
        res = exp.sql_query(query, return_type='arrow')
    except ValueError as e:
        # NOTE(review): Explorer.sql_query is expected to raise ValueError for queries that do not start with
        # SELECT or WHERE - confirm against the installed ultralytics version.
        st.session_state['error'] = f'Invalid query: {e}'
        return
    st.session_state['imgs'] = res.to_pydict()['im_file']
def run_ai_query():
    """
    Send the natural-language prompt from the AI form to the Explorer and store the matching image paths.

    Requires `SETTINGS['openai_api_key']` to be set. Outcomes are reported through the session state: image paths
    go to `st.session_state['imgs']`, problems (missing key, empty or non-DataFrame result) to
    `st.session_state['error']`. A blank prompt only clears the previous error.
    """
    if not SETTINGS['openai_api_key']:
        missing_key_msg = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
        st.session_state['error'] = missing_key_msg
        return
    st.session_state['error'] = None
    prompt = st.session_state.get('ai_query')
    if not prompt.strip():
        return
    result = st.session_state['explorer'].ask_ai(prompt)
    if isinstance(result, pd.DataFrame) and not result.empty:
        st.session_state['imgs'] = result['im_file'].to_list()
    else:
        st.session_state['error'] = 'No results found using AI generated query. Try another query or rerun it.'
def reset_explorer():
    """Set the 'explorer', 'imgs' and 'error' session entries back to None."""
    for key in ('explorer', 'imgs', 'error'):
        st.session_state[key] = None
def utralytics_explorer_docs_callback():
    """Render a bordered info box with the Ultralytics logo, a short blurb and a link to the Explorer API docs."""
    # NOTE: the function name keeps its historical spelling because `layout` calls it by this name.
    with st.container(border=True):
        st.image('https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg',
                 width=100)
        st.markdown(
            "<p>This demo is built using Ultralytics Explorer API. Visit <a href='https://docs.ultralytics.com/datasets/explorer/'>API docs</a> to try examples & learn more</p>",
            unsafe_allow_html=True,
            help=None)
        # Button label previously read "Ultrlaytics" (typo).
        st.link_button('Ultralytics Explorer API', 'https://docs.ultralytics.com/datasets/explorer/')
def layout():
    """
    Build the Explorer demo page.

    Until an Explorer exists in the session state only the dataset/model selection form is shown. Afterwards the
    page is split in two columns: the left one holds the display controls, the SQL and AI query forms and the image
    grid; the right one holds the similarity-search form and the documentation box. If a query stored an error in
    the session state, the error is shown and no images are listed.
    """
    st.set_page_config(layout='wide', initial_sidebar_state='collapsed')
    st.markdown("<h1 style='text-align: center;'>Ultralytics Explorer Demo</h1>", unsafe_allow_html=True)

    if st.session_state.get('explorer') is None:
        init_explorer_form()
        return

    st.button(':arrow_backward: Select Dataset', on_click=reset_explorer)
    exp = st.session_state.get('explorer')
    col1, col2 = st.columns([0.75, 0.25], gap='small')
    imgs = []
    if st.session_state.get('error'):
        st.error(st.session_state['error'])
    else:
        # Show the latest query/similarity result if there is one, otherwise every image in the embeddings table.
        imgs = st.session_state.get('imgs') or exp.table.to_lance().to_table(columns=['im_file']).to_pydict()['im_file']

    total_imgs, selected_imgs = len(imgs), []
    with col1:
        subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5)
        with subcol1:
            st.write('Max Images Displayed:')
        with subcol2:
            num = st.number_input('Max Images Displayed',
                                  min_value=0,
                                  max_value=total_imgs,
                                  value=min(500, total_imgs),
                                  key='num_imgs_displayed',
                                  label_visibility='collapsed')
        with subcol3:
            st.write('Start Index:')
        with subcol4:
            start_idx = st.number_input('Start Index',
                                        min_value=0,
                                        max_value=total_imgs,
                                        value=0,
                                        key='start_index',
                                        label_visibility='collapsed')
        with subcol5:
            reset = st.button('Reset', use_container_width=False, key='reset')
            if reset:
                st.session_state['imgs'] = None
                # st.experimental_rerun() is deprecated; st.rerun() is its stable replacement and is available in
                # the streamlit>=1.29.0 this file already requires.
                st.rerun()

        query_form()
        ai_query_form()
        if total_imgs:
            imgs_displayed = imgs[start_idx:start_idx + num]
            selected_imgs = image_select(
                f'Total samples: {total_imgs}',
                images=imgs_displayed,
                use_container_width=False,
                # indices=[i for i in range(num)] if select_all else None,
            )

    with col2:
        similarity_form(selected_imgs)
        # display_labels = st.checkbox("Labels", value=False, key="display_labels")
        utralytics_explorer_docs_callback()
# Build the page only when this file is executed as a script (e.g. via `streamlit run`), not when it is imported.
if __name__ == '__main__':
    layout()