import time
from threading import Thread

from ultralytics import Explorer
from ultralytics.utils import ROOT
from ultralytics.utils.checks import check_requirements

check_requirements('streamlit')
check_requirements('streamlit-select>=0.2')
import streamlit as st
from streamlit_select import image_select


def _get_explorer():
    """Build an Explorer for the chosen dataset/model and store it in session state.

    Reads ``dataset``, ``model`` and ``force_recreate_embeddings`` from ``st.session_state``
    (the keys of the widgets in ``init_explorer_form``). ``create_embeddings_table`` runs on a
    background thread while a progress bar is polled every 0.1 s. Afterwards the Explorer is
    stored under ``st.session_state['explorer']`` and the bar is removed.
    """
    exp = Explorer(data=st.session_state.get('dataset'), model=st.session_state.get('model'))
    thread = Thread(target=exp.create_embeddings_table,
                    kwargs={'force': st.session_state.get('force_recreate_embeddings')})
    thread.start()
    progress_bar = st.progress(0, text='Creating embeddings table...')
    # Stop polling as soon as the worker exits, not only when progress reaches 1: if
    # create_embeddings_table raises, exp.progress never gets to 1 and the loop would hang forever.
    while thread.is_alive() and exp.progress < 1:
        time.sleep(0.1)
        progress = exp.progress  # read once; the worker thread updates it concurrently
        progress_bar.progress(progress, text=f'Progress: {progress * 100:.1f}%')
    # Join even when progress hit 1, so the worker has fully finished before the Explorer is used.
    thread.join()
    st.session_state['explorer'] = exp
    progress_bar.empty()


def init_explorer_form():
    """Render the form used to pick a dataset and model and start the Explorer.

    Dataset choices are the ``*.yaml`` files under ``ROOT / 'cfg' / 'datasets'``. Widget values
    are stored under the session-state keys ``dataset``, ``model`` and
    ``force_recreate_embeddings``, which ``_get_explorer`` reads when the form is submitted.
    """
    datasets = ROOT / 'cfg' / 'datasets'
    # glob() order is filesystem-dependent; sort for a stable, predictable dropdown.
    ds = sorted(d.name for d in datasets.glob('*.yaml'))
    models = [
        'yolov8n.pt', 'yolov8s.pt', 'yolov8m.pt', 'yolov8l.pt', 'yolov8x.pt', 'yolov8n-seg.pt', 'yolov8s-seg.pt',
        'yolov8m-seg.pt', 'yolov8l-seg.pt', 'yolov8x-seg.pt', 'yolov8n-pose.pt', 'yolov8s-pose.pt', 'yolov8m-pose.pt',
        'yolov8l-pose.pt', 'yolov8x-pose.pt']
    # Prefer coco128 as the default, but don't crash (list.index raises ValueError) when it's absent.
    default_ds_index = ds.index('coco128.yaml') if 'coco128.yaml' in ds else 0
    with st.form(key='explorer_init_form'):
        col1, col2 = st.columns(2)
        with col1:
            st.selectbox('Select dataset', ds, key='dataset', index=default_ds_index)
        with col2:
            st.selectbox('Select model', models, key='model')
        st.checkbox('Force recreate embeddings', key='force_recreate_embeddings')

        st.form_submit_button('Explore', on_click=_get_explorer)


def query_form():
    """Render a single-line SQL query box plus a submit button wired to ``run_sql_query``."""
    with st.form('query_form'):
        input_col, button_col = st.columns([0.8, 0.2])
        with input_col:
            # The typed text lives in st.session_state['query']; run_sql_query reads it from there.
            st.text_input('Query', '', label_visibility='collapsed', key='query')
        with button_col:
            st.form_submit_button('Query', on_click=run_sql_query)


def find_similar_imgs(imgs):
    """Replace the displayed images with those most similar to ``imgs``.

    Args:
        imgs: the selected images, passed straight to ``Explorer.get_similar``.

    The result count comes from the ``limit`` session-state value. The matching ``im_file``
    paths are written to ``st.session_state['imgs']``.
    """
    explorer = st.session_state['explorer']
    result = explorer.get_similar(img=imgs, limit=st.session_state.get('limit'), return_type='arrow')
    st.session_state['imgs'] = result.to_pydict()['im_file']


def similarity_form(selected_imgs):
    """Render the similarity-search form for the currently selected images.

    Args:
        selected_imgs: images picked in the gallery. Their count is displayed, and they are
            forwarded to ``find_similar_imgs`` when Search is pressed. Search is disabled while
            nothing is selected.
    """
    st.write('Similarity Search')
    with st.form('similarity_form'):
        limit_col, action_col = st.columns([1, 1])
        with limit_col:
            # Value is kept in st.session_state['limit'] and read by find_similar_imgs.
            st.number_input('limit', min_value=None, max_value=None, value=25,
                            label_visibility='collapsed', key='limit')

        n_selected = len(selected_imgs)
        nothing_selected = n_selected == 0
        with action_col:
            st.write('Selected: ', n_selected)
            st.form_submit_button('Search',
                                  disabled=nothing_selected,
                                  on_click=find_similar_imgs,
                                  args=(selected_imgs, ))
        if nothing_selected:
            st.error('Select at least one image to search.')


# def persist_reset_form():
#    with st.form("persist_reset"):
#        col1, col2 = st.columns([1, 1])
#        with col1:
#            st.form_submit_button("Reset", on_click=reset)
#
#        with col2:
#            st.form_submit_button("Persist", on_click=update_state, args=("PERSISTING", True))


def run_sql_query():
    """Run the SQL text from ``st.session_state['query']`` and display the matching images.

    Surrounding whitespace is stripped before the query is sent to ``Explorer.sql_query``. A
    missing, None or blank query is ignored, leaving the current image list unchanged. On
    success, the ``im_file`` paths of the result are written to ``st.session_state['imgs']``.
    """
    # `or ''` guards against the key being absent/None, where .strip() would raise AttributeError.
    query = (st.session_state.get('query') or '').strip()
    if not query:
        return
    exp = st.session_state['explorer']
    res = exp.sql_query(query, return_type='arrow')
    st.session_state['imgs'] = res.to_pydict()['im_file']


def reset_explorer():
    """Forget the current Explorer and image list so ``layout`` shows the dataset picker again."""
    for state_key in ('explorer', 'imgs'):
        st.session_state[state_key] = None


def utralytics_explorer_docs_callback():
    """Render a bordered card with the Ultralytics logo and links to the Explorer API docs."""
    docs_url = 'https://docs.ultralytics.com/'
    with st.container(border=True):
        st.image('https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg',
                 width=100)
        # The anchor previously had an empty href, which just pointed back at the app itself.
        st.markdown(
            f"<p>This demo is built using Ultralytics Explorer API. Visit <a href='{docs_url}'>API docs</a> to try examples & learn more</p>",
            unsafe_allow_html=True,
            help=None)
        st.link_button('Ultralytics Explorer API', docs_url)


def layout():
    """Render the Explorer dashboard page.

    Until an Explorer exists in ``st.session_state['explorer']``, only the dataset/model form is
    shown. After that, the page has two columns:

    - The left column holds pagination controls, the SQL query box and an image gallery. The
      gallery shows either the last search results (``st.session_state['imgs']``) or, when that
      is None, every ``im_file`` in the embeddings table.
    - The right column holds the similarity-search form and a docs card.
    """
    st.set_page_config(layout='wide', initial_sidebar_state='collapsed')
    st.markdown("<h1 style='text-align: center;'>Ultralytics Explorer Demo</h1>", unsafe_allow_html=True)

    if st.session_state.get('explorer') is None:
        init_explorer_form()
        return

    st.button(':arrow_backward: Select Dataset', on_click=reset_explorer)
    exp = st.session_state.get('explorer')
    col1, col2 = st.columns([0.75, 0.25], gap='small')

    # None means "no active search" (initial state, or after Reset). Check for None explicitly:
    # a search that matched nothing returns [], which must show an empty gallery, not every image.
    imgs = st.session_state.get('imgs')
    if imgs is None:
        imgs = exp.table.to_lance().to_table(columns=['im_file']).to_pydict()['im_file']
    total_imgs = len(imgs)
    # The gallery (and therefore a selection) only exists when there are images, but
    # similarity_form below always needs a value; otherwise this would raise NameError.
    selected_imgs = []
    with col1:
        subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5)
        with subcol1:
            st.write('Max Images Displayed:')
        with subcol2:
            num = st.number_input('Max Images Displayed',
                                  min_value=0,
                                  max_value=total_imgs,
                                  value=min(500, total_imgs),
                                  key='num_imgs_displayed',
                                  label_visibility='collapsed')
        with subcol3:
            st.write('Start Index:')
        with subcol4:
            start_idx = st.number_input('Start Index',
                                        min_value=0,
                                        max_value=total_imgs,
                                        value=0,
                                        key='start_index',
                                        label_visibility='collapsed')
        with subcol5:
            reset = st.button('Reset', use_container_width=False, key='reset')
            if reset:
                st.session_state['imgs'] = None
                # st.experimental_rerun is deprecated/removed in recent Streamlit; st.rerun replaces it.
                st.rerun()

        query_form()
        if total_imgs:
            imgs_displayed = imgs[start_idx:start_idx + num]
            selected_imgs = image_select(
                f'Total samples: {total_imgs}',
                images=imgs_displayed,
                use_container_width=False,
                # indices=[i for i in range(num)] if select_all else None,
            )

    with col2:
        similarity_form(selected_imgs)
        # display_labels = st.checkbox("Labels", value=False, key="display_labels")
        utralytics_explorer_docs_callback()


if __name__ == '__main__':
    # Script entry point: build the page for this run of the app.
    layout()