Add type hinting to explorer.py (#7388)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Burhan 2024-01-08 12:57:53 -05:00 committed by GitHub
parent e19398a537
commit d0562d7a2f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,11 +1,12 @@
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import List from typing import Any, List, Tuple, Union
import cv2 import cv2
import numpy as np import numpy as np
import torch import torch
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from pandas import DataFrame
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
@ -13,18 +14,18 @@ from ultralytics.data.augment import Format
from ultralytics.data.dataset import YOLODataset from ultralytics.data.dataset import YOLODataset
from ultralytics.data.utils import check_det_dataset from ultralytics.data.utils import check_det_dataset
from ultralytics.models.yolo.model import YOLO from ultralytics.models.yolo.model import YOLO
from ultralytics.utils import LOGGER, checks from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks
from .utils import get_sim_index_schema, get_table_schema, plot_similar_images, sanitize_batch from .utils import get_sim_index_schema, get_table_schema, plot_similar_images, sanitize_batch
class ExplorerDataset(YOLODataset): class ExplorerDataset(YOLODataset):
def __init__(self, *args, data=None, **kwargs): def __init__(self, *args, data: dict = None, **kwargs) -> None:
super().__init__(*args, data=data, **kwargs) super().__init__(*args, data=data, **kwargs)
# NOTE: Load the image directly without any resize operations. # NOTE: Load the image directly without any resize operations.
def load_image(self, i): def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]:
"""Loads 1 image from dataset index 'i', returns (im, resized hw).""" """Loads 1 image from dataset index 'i', returns (im, resized hw)."""
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i] im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
if im is None: # not cached in RAM if im is None: # not cached in RAM
@ -39,7 +40,7 @@ class ExplorerDataset(YOLODataset):
return self.ims[i], self.im_hw0[i], self.im_hw[i] return self.ims[i], self.im_hw0[i], self.im_hw[i]
def build_transforms(self, hyp=None): def build_transforms(self, hyp: IterableSimpleNamespace = None):
return Format( return Format(
bbox_format='xyxy', bbox_format='xyxy',
normalize=False, normalize=False,
@ -53,7 +54,10 @@ class ExplorerDataset(YOLODataset):
class Explorer: class Explorer:
def __init__(self, data='coco128.yaml', model='yolov8n.pt', uri='~/ultralytics/explorer') -> None: def __init__(self,
data: Union[str, Path] = 'coco128.yaml',
model: str = 'yolov8n.pt',
uri: str = '~/ultralytics/explorer') -> None:
checks.check_requirements(['lancedb', 'duckdb']) checks.check_requirements(['lancedb', 'duckdb'])
import lancedb import lancedb
@ -68,7 +72,7 @@ class Explorer:
self.table = None self.table = None
self.progress = 0 self.progress = 0
def create_embeddings_table(self, force=False, split='train'): def create_embeddings_table(self, force: bool = False, split: str = 'train') -> None:
""" """
Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
already exists. Pass force=True to overwrite the existing table. already exists. Pass force=True to overwrite the existing table.
@ -118,7 +122,7 @@ class Explorer:
self.table = table self.table = table
def _yield_batches(self, dataset, data_info, model, exclude_keys: List): def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]):
# Implement Batching # Implement Batching
for i in tqdm(range(len(dataset))): for i in tqdm(range(len(dataset))):
self.progress = float(i + 1) / len(dataset) self.progress = float(i + 1) / len(dataset)
@ -129,7 +133,9 @@ class Explorer:
batch['vector'] = model.embed(batch['im_file'], verbose=False)[0].detach().tolist() batch['vector'] = model.embed(batch['im_file'], verbose=False)[0].detach().tolist()
yield [batch] yield [batch]
def query(self, imgs=None, limit=25): def query(self,
imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
limit: int = 25) -> Any: # pyarrow.Table
""" """
Query the table for similar images. Accepts a single image or a list of images. Query the table for similar images. Accepts a single image or a list of images.
@ -162,7 +168,9 @@ class Explorer:
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy() embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
return self.table.search(embeds).limit(limit).to_arrow() return self.table.search(embeds).limit(limit).to_arrow()
def sql_query(self, query, return_type='pandas'): def sql_query(self,
query: str,
return_type: str = 'pandas') -> Union[DataFrame, Any, None]: # pandas.dataframe or pyarrow.Table
""" """
Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown. Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.
@ -177,7 +185,7 @@ class Explorer:
```python ```python
exp = Explorer() exp = Explorer()
exp.create_embeddings_table() exp.create_embeddings_table()
query = 'SELECT * FROM table WHERE labels LIKE "%person%"' query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
result = exp.sql_query(query) result = exp.sql_query(query)
``` ```
""" """
@ -201,7 +209,7 @@ class Explorer:
elif return_type == 'arrow': elif return_type == 'arrow':
return rs.arrow() return rs.arrow()
def plot_sql_query(self, query, labels=True): def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:
""" """
Plot the results of a SQL-Like query on the table. Plot the results of a SQL-Like query on the table.
Args: Args:
@ -215,7 +223,7 @@ class Explorer:
```python ```python
exp = Explorer() exp = Explorer()
exp.create_embeddings_table() exp.create_embeddings_table()
query = 'SELECT * FROM table WHERE labels LIKE "%person%"' query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
result = exp.plot_sql_query(query) result = exp.plot_sql_query(query)
``` ```
""" """
@ -223,7 +231,11 @@ class Explorer:
img = plot_similar_images(result, plot_labels=labels) img = plot_similar_images(result, plot_labels=labels)
return Image.fromarray(img) return Image.fromarray(img)
def get_similar(self, img=None, idx=None, limit=25, return_type='pandas'): def get_similar(self,
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
idx: Union[int, List[int]] = None,
limit: int = 25,
return_type: str = 'pandas') -> Union[DataFrame, Any]: # pandas.dataframe or pyarrow.Table
""" """
Query the table for similar images. Accepts a single image or a list of images. Query the table for similar images. Accepts a single image or a list of images.
@ -251,7 +263,11 @@ class Explorer:
elif return_type == 'arrow': elif return_type == 'arrow':
return similar return similar
def plot_similar(self, img=None, idx=None, limit=25, labels=True): def plot_similar(self,
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
idx: Union[int, List[int]] = None,
limit: int = 25,
labels: bool = True) -> Image.Image:
""" """
Plot the similar images. Accepts images or indexes. Plot the similar images. Accepts images or indexes.
@ -275,7 +291,7 @@ class Explorer:
img = plot_similar_images(similar, plot_labels=labels) img = plot_similar_images(similar, plot_labels=labels)
return Image.fromarray(img) return Image.fromarray(img)
def similarity_index(self, max_dist=0.2, top_k=None, force=False): def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame:
""" """
Calculate the similarity index of all the images in the table. Here, the index will contain the data points that Calculate the similarity index of all the images in the table. Here, the index will contain the data points that
are max_dist or closer to the image in the embedding space at a given index. are max_dist or closer to the image in the embedding space at a given index.
@ -329,7 +345,7 @@ class Explorer:
self.sim_index = sim_table self.sim_index = sim_table
return sim_table.to_pandas() return sim_table.to_pandas()
def plot_similarity_index(self, max_dist=0.2, top_k=None, force=False): def plot_similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Image:
""" """
Plot the similarity index of all the images in the table. Here, the index will contain the data points that are Plot the similarity index of all the images in the table. Here, the index will contain the data points that are
max_dist or closer to the image in the embedding space at a given index. max_dist or closer to the image in the embedding space at a given index.
@ -341,13 +357,16 @@ class Explorer:
force (bool): Whether to overwrite the existing similarity index or not. Defaults to False. force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.
Returns: Returns:
PIL Image containing the plot. PIL.Image.Image containing the plot.
Example: Example:
```python ```python
exp = Explorer() exp = Explorer()
exp.create_embeddings_table() exp.create_embeddings_table()
exp.plot_similarity_index()
similarity_idx_plot = exp.plot_similarity_index()
similarity_idx_plot.show() # view image preview
similarity_idx_plot.save('path/to/save/similarity_index_plot.png') # save contents to file
``` ```
""" """
sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force) sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
@ -368,9 +387,10 @@ class Explorer:
buffer.seek(0) buffer.seek(0)
# Use Pillow to open the image from the buffer # Use Pillow to open the image from the buffer
return Image.open(buffer) return Image.fromarray(np.array(Image.open(buffer)))
def _check_imgs_or_idxs(self, img, idx): def _check_imgs_or_idxs(self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None],
idx: Union[None, int, List[int]]) -> List[np.ndarray]:
if img is None and idx is None: if img is None and idx is None:
raise ValueError('Either img or idx must be provided.') raise ValueError('Either img or idx must be provided.')
if img is not None and idx is not None: if img is not None and idx is not None: