"""Streamlit dashboard for browsing a dataset with the Ultralytics Explorer API."""

import time
from threading import Thread

import pandas as pd

from ultralytics import Explorer
from ultralytics.utils import ROOT, SETTINGS
from ultralytics.utils.checks import check_requirements

check_requirements('streamlit>=1.29.0')
check_requirements('streamlit-select>=0.2')

# Imported only after the requirement checks above so a missing package is reported clearly.
import streamlit as st  # noqa: E402
from streamlit_select import image_select  # noqa: E402


def _get_explorer():
    """Build an Explorer from the form selections and create its embeddings table.

    Table creation runs in a background thread while a progress bar mirrors ``exp.progress``.
    On success the Explorer is stored in ``st.session_state['explorer']``; if no table was
    produced (e.g. the worker raised), an error is shown and no Explorer is stored.
    """
    exp = Explorer(data=st.session_state.get('dataset'), model=st.session_state.get('model'))
    thread = Thread(target=exp.create_embeddings_table,
                    kwargs={'force': st.session_state.get('force_recreate_embeddings')})
    thread.start()

    progress_bar = st.progress(0, text='Creating embeddings table...')
    # Also stop when the worker dies: if it raises, progress never reaches 1 and a
    # progress-only loop condition would spin forever.
    while thread.is_alive() and exp.progress < 1:
        time.sleep(0.1)
        progress_bar.progress(exp.progress, text=f'Progress: {exp.progress * 100:.0f}%')
    thread.join()
    progress_bar.empty()

    # layout() dereferences exp.table, so never store an Explorer without one.
    if exp.table is None:
        st.error('Failed to create embeddings table. Check the logs for details.')
        return
    st.session_state['explorer'] = exp


def init_explorer_form():
    """Render the form used to choose a dataset and model before exploring."""
    datasets = ROOT / 'cfg' / 'datasets'
    ds = sorted(d.name for d in datasets.glob('*.yaml'))  # sorted: glob order is not guaranteed
    models = [
        'yolov8n.pt', 'yolov8s.pt', 'yolov8m.pt', 'yolov8l.pt', 'yolov8x.pt',
        'yolov8n-seg.pt', 'yolov8s-seg.pt', 'yolov8m-seg.pt', 'yolov8l-seg.pt', 'yolov8x-seg.pt',
        'yolov8n-pose.pt', 'yolov8s-pose.pt', 'yolov8m-pose.pt', 'yolov8l-pose.pt', 'yolov8x-pose.pt']
    with st.form(key='explorer_init_form'):
        col1, col2 = st.columns(2)
        with col1:
            # Prefer coco128 as the default, but don't crash if it isn't shipped.
            default_index = ds.index('coco128.yaml') if 'coco128.yaml' in ds else 0
            st.selectbox('Select dataset', ds, key='dataset', index=default_index)
        with col2:
            st.selectbox('Select model', models, key='model')
        st.checkbox('Force recreate embeddings', key='force_recreate_embeddings')

        st.form_submit_button('Explore', on_click=_get_explorer)


def query_form():
    """Render the SQL query form; submitting it calls ``run_sql_query``."""
    with st.form('query_form'):
        col1, col2 = st.columns([0.8, 0.2])
        with col1:
            st.text_input('Query',
                          "WHERE labels LIKE '%person%' AND labels LIKE '%dog%'",
                          label_visibility='collapsed',
                          key='query')
        with col2:
            st.form_submit_button('Query', on_click=run_sql_query)


def ai_query_form():
    """Render the natural-language query form; submitting it calls ``run_ai_query``."""
    with st.form('ai_query_form'):
        col1, col2 = st.columns([0.8, 0.2])
        with col1:
            st.text_input('Query', 'Show images with 1 person and 1 dog', label_visibility='collapsed', key='ai_query')
        with col2:
            st.form_submit_button('Ask AI', on_click=run_ai_query)


def find_similar_imgs(imgs):
    """Store in ``st.session_state['imgs']`` the paths of images similar to ``imgs``.

    Args:
        imgs: The selected images, passed straight to ``Explorer.get_similar``.
    """
    exp = st.session_state['explorer']
    similar = exp.get_similar(img=imgs, limit=st.session_state.get('limit'), return_type='arrow')
    paths = similar.to_pydict()['im_file']
    st.session_state['imgs'] = paths


def similarity_form(selected_imgs):
    """Render the similarity-search form for the currently selected images.

    Args:
        selected_imgs: Images picked in the grid; the search button is disabled when empty.
    """
    st.write('Similarity Search')
    with st.form('similarity_form'):
        subcol1, subcol2 = st.columns([1, 1])
        with subcol1:
            # A result limit below 1 is meaningless, so the input is bounded at 1.
            st.number_input('limit', min_value=1, max_value=None, value=25, label_visibility='collapsed', key='limit')

        with subcol2:
            disabled = not len(selected_imgs)
            st.write('Selected: ', len(selected_imgs))
            st.form_submit_button(
                'Search',
                disabled=disabled,
                on_click=find_similar_imgs,
                args=(selected_imgs, ),
            )
        if disabled:
            st.error('Select at least one image to search.')


def run_sql_query():
    """Run the SQL query from the query form and store the matching image paths.

    Failures and empty result sets are reported through ``st.session_state['error']``
    (mirroring ``run_ai_query``) instead of crashing the callback or silently falling
    back to the full dataset.
    """
    st.session_state['error'] = None
    query = (st.session_state.get('query') or '').strip()
    if not query:
        return
    exp = st.session_state['explorer']
    try:
        res = exp.sql_query(query, return_type='arrow')
    except Exception as e:  # UI boundary: show the problem to the user instead of crashing
        st.session_state['error'] = f'SQL query failed: {e}'
        return
    paths = res.to_pydict()['im_file'] if res is not None else []
    if not paths:
        st.session_state['error'] = 'No results found using the SQL query. Try another query.'
        return
    st.session_state['imgs'] = paths


def run_ai_query():
    """Ask the Explorer's AI helper to answer the natural-language query and store the results.

    Requires an OpenAI API key in ``SETTINGS``; problems are reported through
    ``st.session_state['error']``.
    """
    if not SETTINGS['openai_api_key']:
        st.session_state[
            'error'] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
        return
    st.session_state['error'] = None
    query = (st.session_state.get('ai_query') or '').strip()
    if not query:
        return
    exp = st.session_state['explorer']
    res = exp.ask_ai(query)
    if not isinstance(res, pd.DataFrame) or res.empty:
        st.session_state['error'] = 'No results found using AI generated query. Try another query or rerun it.'
        return
    st.session_state['imgs'] = res['im_file'].to_list()


def reset_explorer():
    """Clear the Explorer, query results and error so the dataset selection form is shown again."""
    st.session_state['explorer'] = None
    st.session_state['imgs'] = None
    st.session_state['error'] = None


def utralytics_explorer_docs_callback():
    """Render a bordered box linking to the Ultralytics Explorer API documentation."""
    with st.container(border=True):
        st.image('https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg',
                 width=100)
        # NOTE(review): this literal looks like it lost its HTML markup in this copy
        # ("API docs" reads like link text) — confirm against upstream.
        st.markdown(
            'This demo is built using Ultralytics Explorer API. Visit API docs to try examples & learn more',
            unsafe_allow_html=True,
            help=None,
        )
        st.link_button('Ultralytics Explorer API', 'https://docs.ultralytics.com/datasets/explorer/')


def layout():
    """Render the dashboard: the dataset form, or the image grid with query and similarity tools."""
    st.set_page_config(layout='wide', initial_sidebar_state='collapsed')
    # NOTE(review): this title literal looks like it lost its HTML markup in this copy — confirm upstream.
    st.markdown('Ultralytics Explorer Demo', unsafe_allow_html=True)

    if st.session_state.get('explorer') is None:
        init_explorer_form()
        return

    st.button(':arrow_backward: Select Dataset', on_click=reset_explorer)
    exp = st.session_state.get('explorer')
    col1, col2 = st.columns([0.75, 0.25], gap='small')
    imgs = []
    if st.session_state.get('error'):
        st.error(st.session_state['error'])
    else:
        # Show the latest query/similarity results, or every image in the table if none are stored.
        imgs = st.session_state.get('imgs') or exp.table.to_lance().to_table(
            columns=['im_file']).to_pydict()['im_file']
    total_imgs, selected_imgs = len(imgs), []
    with col1:
        subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5)
        with subcol1:
            st.write('Max Images Displayed:')
        with subcol2:
            num = st.number_input('Max Images Displayed',
                                  min_value=0,
                                  max_value=total_imgs,
                                  value=min(500, total_imgs),
                                  key='num_imgs_displayed',
                                  label_visibility='collapsed')
        with subcol3:
            st.write('Start Index:')
        with subcol4:
            start_idx = st.number_input('Start Index',
                                        min_value=0,
                                        max_value=total_imgs,
                                        value=0,
                                        key='start_index',
                                        label_visibility='collapsed')
        with subcol5:
            reset = st.button('Reset', use_container_width=False, key='reset')
            if reset:
                # Clear the error too; otherwise a failed query keeps hiding the dataset after Reset.
                st.session_state['imgs'] = None
                st.session_state['error'] = None
                st.rerun()

        query_form()
        ai_query_form()
        if total_imgs:
            imgs_displayed = imgs[start_idx:start_idx + num]
            selected_imgs = image_select(
                f'Total samples: {total_imgs}',
                images=imgs_displayed,
                use_container_width=False,
            )

    with col2:
        similarity_form(selected_imgs)
        utralytics_explorer_docs_callback()


if __name__ == '__main__':
    layout()