from io import BytesIO
from pathlib import Path
from typing import List
import cv2
import numpy as np
import torch
from matplotlib import pyplot as plt
from PIL import Image
from tqdm import tqdm
from ultralytics.data.augment import Format
from ultralytics.data.dataset import YOLODataset
from ultralytics.data.utils import check_det_dataset
from ultralytics.models.yolo.model import YOLO
from ultralytics.utils import LOGGER, checks
from .utils import get_sim_index_schema, get_table_schema, plot_similar_images, sanitize_batch
class ExplorerDataset(YOLODataset):
def __init__(self, *args, data=None, **kwargs):
super().__init__(*args, data=data, **kwargs)
# NOTE: Load the image directly without any resize operations.
def load_image(self, i):
"""Loads 1 image from dataset index 'i', returns (im, resized hw)."""
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
if im is None: # not cached in RAM
if fn.exists(): # load npy
im = np.load(fn)
else: # read image
im = cv2.imread(f) # BGR
if im is None:
raise FileNotFoundError(f'Image Not Found {f}')
h0, w0 = im.shape[:2] # orig hw
return im, (h0, w0), im.shape[:2]
return self.ims[i], self.im_hw0[i], self.im_hw[i]
def build_transforms(self, hyp=None):
"""Creates transforms for dataset images without resizing."""
transforms = Format(
bbox_format='xyxy',
normalize=False,
return_mask=self.use_segments,
return_keypoint=self.use_keypoints,
batch_idx=True,
mask_ratio=hyp.mask_ratio,
mask_overlap=hyp.overlap_mask,
)
return transforms
class Explorer:
def __init__(self, data='coco128.yaml', model='yolov8n.pt', uri='~/ultralytics/explorer') -> None:
checks.check_requirements(['lancedb', 'duckdb'])
import lancedb
self.connection = lancedb.connect(uri)
self.table_name = Path(data).name.lower() + '_' + model.lower()
self.sim_idx_base_name = f'{self.table_name}_sim_idx'.lower()  # Use this name and append thres and top_k to reuse the table
self.model = YOLO(model)
self.data = data # None
self.choice_set = None
self.table = None
self.progress = 0
def create_embeddings_table(self, force=False, split='train'):
"""
Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
already exists. Pass force=True to overwrite the existing table.
Args:
force (bool): Whether to overwrite the existing table or not. Defaults to False.
split (str): Split of the dataset to use. Defaults to 'train'.
Example:
```python
exp = Explorer()
exp.create_embeddings_table()
```
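To rebuild the table from a different split, something like the following should work (a sketch; assumes the dataset YAML defines a 'val' key):
```python
exp.create_embeddings_table(force=True, split='val')  # overwrite the existing table using the 'val' split
```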
"""
if self.table is not None and not force:
LOGGER.info('Table already exists. Reusing it. Pass force=True to overwrite it.')
return
if self.table_name in self.connection.table_names() and not force:
LOGGER.info(f'Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.')
self.table = self.connection.open_table(self.table_name)
self.progress = 1
return
if self.data is None:
raise ValueError('Data must be provided to create embeddings table')
data_info = check_det_dataset(self.data)
if split not in data_info:
raise ValueError(
f'Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}'
)
choice_set = data_info[split]
choice_set = choice_set if isinstance(choice_set, list) else [choice_set]
self.choice_set = choice_set
dataset = ExplorerDataset(img_path=choice_set, data=data_info, augment=False, cache=False, task=self.model.task)
# Create the table schema
batch = dataset[0]
vector_size = self.model.embed(batch['im_file'], verbose=False)[0].shape[0]
Schema = get_table_schema(vector_size)
table = self.connection.create_table(self.table_name, schema=Schema, mode='overwrite')
table.add(
self._yield_batches(dataset,
data_info,
self.model,
exclude_keys=['img', 'ratio_pad', 'resized_shape', 'ori_shape', 'batch_idx']))
self.table = table
def _yield_batches(self, dataset, data_info, model, exclude_keys: List):
# Implement Batching
for i in tqdm(range(len(dataset))):
self.progress = float(i + 1) / len(dataset)
batch = dataset[i]
for k in exclude_keys:
batch.pop(k, None)
batch = sanitize_batch(batch, data_info)
batch['vector'] = model.embed(batch['im_file'], verbose=False)[0].detach().tolist()
yield [batch]
def query(self, imgs=None, limit=25):
"""
Query the table for similar images. Accepts a single image or a list of images.
Args:
imgs (str or list): Path to the image or a list of paths to the images.
limit (int): Number of results to return.
Returns:
An arrow table containing the results. Supports converting to:
- pandas dataframe: `result.to_pandas()`
- dict of lists: `result.to_pydict()`
Example:
```python
exp = Explorer()
exp.create_embeddings_table()
similar = exp.query(imgs='https://ultralytics.com/images/zidane.jpg')
```
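Multiple images can also be passed, in which case their embeddings are averaged before searching. A sketch with placeholder paths:
```python
similar = exp.query(imgs=['path/to/img1.jpg', 'path/to/img2.jpg'], limit=10)
df = similar.to_pandas()  # convert the returned Arrow table to a pandas DataFrame
```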
"""
if self.table is None:
raise ValueError('Table is not created. Please create the table first.')
if isinstance(imgs, str):
imgs = [imgs]
elif isinstance(imgs, list):
pass
else:
raise ValueError(f'imgs must be a string or a list of strings. Got {type(imgs)}')
embeds = self.model.embed(imgs)
# Get avg if multiple images are passed (len > 1)
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
query = self.table.search(embeds).limit(limit).to_arrow()
return query
def sql_query(self, query, return_type='pandas'):
"""
Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.
Args:
query (str): SQL query to run.
return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.
Returns:
A pandas dataframe or an arrow table containing the results, depending on return_type.
Example:
```python
exp = Explorer()
exp.create_embeddings_table()
query = "SELECT * FROM table WHERE labels LIKE '%person%'"
result = exp.sql_query(query)
```
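Because the method prepends `SELECT * FROM 'table'` to queries that start with WHERE, a bare WHERE clause is also accepted (a sketch):
```python
result = exp.sql_query("WHERE labels LIKE '%person%' LIMIT 10")
```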
"""
import duckdb
if self.table is None:
raise ValueError('Table is not created. Please create the table first.')
# Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
table = self.table.to_arrow()  # noqa (kept: DuckDB resolves the local 'table' variable by name in the query)
if not query.startswith('SELECT') and not query.startswith('WHERE'):
raise ValueError(
'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause.')
if query.startswith('WHERE'):
query = f"SELECT * FROM 'table' {query}"
LOGGER.info(f'Running query: {query}')
rs = duckdb.sql(query)
if return_type == 'pandas':
return rs.df()
elif return_type == 'arrow':
return rs.arrow()
def plot_sql_query(self, query, labels=True):
"""
Plot the results of a SQL-Like query on the table.
Args:
query (str): SQL query to run.
labels (bool): Whether to plot the labels or not.
Returns:
PIL Image containing the plot.
Example:
```python
exp = Explorer()
exp.create_embeddings_table()
query = "SELECT * FROM table WHERE labels LIKE '%person%'"
result = exp.plot_sql_query(query)
```
"""
result = self.sql_query(query, return_type='arrow')
img = plot_similar_images(result, plot_labels=labels)
img = Image.fromarray(img)
return img
def get_similar(self, img=None, idx=None, limit=25, return_type='pandas'):
"""
Query the table for similar images. Accepts a single image or a list of images.
Args:
img (str or list): Path to the image or a list of paths to the images.
idx (int or list): Index of the image in the table or a list of indexes.
limit (int): Number of results to return. Defaults to 25.
return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.
Returns:
A pandas dataframe or an arrow table containing the results, depending on return_type.
Example:
```python
exp = Explorer()
exp.create_embeddings_table()
similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
```
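A row index into the table can be used instead of an image, e.g. a sketch using the first row:
```python
similar = exp.get_similar(idx=0, limit=10, return_type='arrow')
```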
"""
img = self._check_imgs_or_idxs(img, idx)
similar = self.query(img, limit=limit)
if return_type == 'pandas':
return similar.to_pandas()
elif return_type == 'arrow':
return similar
def plot_similar(self, img=None, idx=None, limit=25, labels=True):
"""
Plot the similar images. Accepts images or indexes.
Args:
img (str or list): Path to the image or a list of paths to the images.
idx (int or list): Index of the image in the table or a list of indexes.
labels (bool): Whether to plot the labels or not.
limit (int): Number of results to return. Defaults to 25.
Returns:
PIL Image containing the plot.
Example:
```python
exp = Explorer()
exp.create_embeddings_table()
similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg')
```
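The returned PIL image can be displayed or saved directly (the output filename below is only illustrative):
```python
similar.show()  # open in the default image viewer
similar.save('similar.png')  # write the plot to disk
```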
"""
similar = self.get_similar(img, idx, limit, return_type='arrow')
img = plot_similar_images(similar, plot_labels=labels)
img = Image.fromarray(img)
return img
def similarity_index(self, max_dist=0.2, top_k=None, force=False):
"""
Calculate the similarity index of all the images in the table. For each image, the index records the data points
that lie within max_dist of its embedding, optionally limited to the closest top_k fraction of the table.
Args:
max_dist (float): Maximum L2 distance between the embeddings to consider. Defaults to 0.2.
top_k (float): Fraction of the closest data points to consider when counting. Used to apply a limit when
running vector search. Defaults to None (use the whole table).
force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.
Returns:
A pandas dataframe containing the similarity index.
Example:
```python
exp = Explorer()
exp.create_embeddings_table()
sim_idx = exp.similarity_index()
```
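A sketch of restricting the count to the closest 1% of the table within an L2 distance of 0.2:
```python
sim_idx = exp.similarity_index(max_dist=0.2, top_k=0.01, force=True)
```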
"""
if self.table is None:
raise ValueError('Table is not created. Please create the table first.')
sim_idx_table_name = f'{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}'.lower()
if sim_idx_table_name in self.connection.table_names() and not force:
LOGGER.info('Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.')
return self.connection.open_table(sim_idx_table_name).to_pandas()
if top_k and not (1.0 >= top_k >= 0.0):
raise ValueError(f'top_k must be between 0.0 and 1.0. Got {top_k}')
if max_dist < 0.0:
raise ValueError(f'max_dist must be greater than or equal to 0. Got {max_dist}')
top_k = int(top_k * len(self.table)) if top_k else len(self.table)
top_k = max(top_k, 1)
features = self.table.to_lance().to_table(columns=['vector', 'im_file']).to_pydict()
im_files = features['im_file']
embeddings = features['vector']
sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode='overwrite')
def _yield_sim_idx():
for i in tqdm(range(len(embeddings))):
sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f'_distance <= {max_dist}')
yield [{
'idx': i,
'im_file': im_files[i],
'count': len(sim_idx),
'sim_im_files': sim_idx['im_file'].tolist()}]
sim_table.add(_yield_sim_idx())
self.sim_index = sim_table
return sim_table.to_pandas()
def plot_similarity_index(self, max_dist=0.2, top_k=None, force=False):
"""
Plot the similarity index of all the images in the table. For each image, the index records the data points
that lie within max_dist of its embedding, optionally limited to the closest top_k fraction of the table.
Args:
max_dist (float): Maximum L2 distance between the embeddings to consider. Defaults to 0.2.
top_k (float): Fraction of the closest data points to consider when counting. Used to apply a limit when
running vector search. Defaults to None (use the whole table).
force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.
Returns:
PIL Image containing the plot.
Example:
```python
exp = Explorer()
exp.create_embeddings_table()
exp.plot_similarity_index()
```
"""
sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
sim_count = sim_idx['count'].tolist()
sim_count = np.array(sim_count)
indices = np.arange(len(sim_count))
# Create the bar plot
plt.bar(indices, sim_count)
# Customize the plot (optional)
plt.xlabel('data idx')
plt.ylabel('Count')
plt.title('Similarity Count')
buffer = BytesIO()
plt.savefig(buffer, format='png')
buffer.seek(0)
# Use Pillow to open the image from the buffer
image = Image.open(buffer)
return image
def _check_imgs_or_idxs(self, img, idx):
if img is None and idx is None:
raise ValueError('Either img or idx must be provided.')
if img is not None and idx is not None:
raise ValueError('Only one of img or idx must be provided.')
if idx is not None:
idx = idx if isinstance(idx, list) else [idx]
img = self.table.to_lance().take(idx, columns=['im_file']).to_pydict()['im_file']
img = img if isinstance(img, list) else [img]
return img
def visualize(self, result):
"""
Visualize the results of a query.
Args:
result (arrow table): Arrow table containing the results of a query.
"""
# TODO:
pass
def generate_report(self, result):
"""Generate a report of the dataset."""
pass