mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
Add type hinting to explorer.py (#7388)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
e19398a537
commit
d0562d7a2f
@ -1,11 +1,12 @@
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import Any, List, Tuple, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
from pandas import DataFrame
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
@ -13,18 +14,18 @@ from ultralytics.data.augment import Format
|
||||
from ultralytics.data.dataset import YOLODataset
|
||||
from ultralytics.data.utils import check_det_dataset
|
||||
from ultralytics.models.yolo.model import YOLO
|
||||
from ultralytics.utils import LOGGER, checks
|
||||
from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks
|
||||
|
||||
from .utils import get_sim_index_schema, get_table_schema, plot_similar_images, sanitize_batch
|
||||
|
||||
|
||||
class ExplorerDataset(YOLODataset):
|
||||
|
||||
def __init__(self, *args, data=None, **kwargs):
|
||||
def __init__(self, *args, data: dict = None, **kwargs) -> None:
|
||||
super().__init__(*args, data=data, **kwargs)
|
||||
|
||||
# NOTE: Load the image directly without any resize operations.
|
||||
def load_image(self, i):
|
||||
def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]:
|
||||
"""Loads 1 image from dataset index 'i', returns (im, resized hw)."""
|
||||
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
|
||||
if im is None: # not cached in RAM
|
||||
@ -39,7 +40,7 @@ class ExplorerDataset(YOLODataset):
|
||||
|
||||
return self.ims[i], self.im_hw0[i], self.im_hw[i]
|
||||
|
||||
def build_transforms(self, hyp=None):
|
||||
def build_transforms(self, hyp: IterableSimpleNamespace = None):
|
||||
return Format(
|
||||
bbox_format='xyxy',
|
||||
normalize=False,
|
||||
@ -53,7 +54,10 @@ class ExplorerDataset(YOLODataset):
|
||||
|
||||
class Explorer:
|
||||
|
||||
def __init__(self, data='coco128.yaml', model='yolov8n.pt', uri='~/ultralytics/explorer') -> None:
|
||||
def __init__(self,
|
||||
data: Union[str, Path] = 'coco128.yaml',
|
||||
model: str = 'yolov8n.pt',
|
||||
uri: str = '~/ultralytics/explorer') -> None:
|
||||
checks.check_requirements(['lancedb', 'duckdb'])
|
||||
import lancedb
|
||||
|
||||
@ -68,7 +72,7 @@ class Explorer:
|
||||
self.table = None
|
||||
self.progress = 0
|
||||
|
||||
def create_embeddings_table(self, force=False, split='train'):
|
||||
def create_embeddings_table(self, force: bool = False, split: str = 'train') -> None:
|
||||
"""
|
||||
Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
|
||||
already exists. Pass force=True to overwrite the existing table.
|
||||
@ -118,7 +122,7 @@ class Explorer:
|
||||
|
||||
self.table = table
|
||||
|
||||
def _yield_batches(self, dataset, data_info, model, exclude_keys: List):
|
||||
def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]):
|
||||
# Implement Batching
|
||||
for i in tqdm(range(len(dataset))):
|
||||
self.progress = float(i + 1) / len(dataset)
|
||||
@ -129,7 +133,9 @@ class Explorer:
|
||||
batch['vector'] = model.embed(batch['im_file'], verbose=False)[0].detach().tolist()
|
||||
yield [batch]
|
||||
|
||||
def query(self, imgs=None, limit=25):
|
||||
def query(self,
|
||||
imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
|
||||
limit: int = 25) -> Any: # pyarrow.Table
|
||||
"""
|
||||
Query the table for similar images. Accepts a single image or a list of images.
|
||||
|
||||
@ -162,7 +168,9 @@ class Explorer:
|
||||
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
|
||||
return self.table.search(embeds).limit(limit).to_arrow()
|
||||
|
||||
def sql_query(self, query, return_type='pandas'):
|
||||
def sql_query(self,
|
||||
query: str,
|
||||
return_type: str = 'pandas') -> Union[DataFrame, Any, None]: # pandas.dataframe or pyarrow.Table
|
||||
"""
|
||||
Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.
|
||||
|
||||
@ -177,7 +185,7 @@ class Explorer:
|
||||
```python
|
||||
exp = Explorer()
|
||||
exp.create_embeddings_table()
|
||||
query = 'SELECT * FROM table WHERE labels LIKE "%person%"'
|
||||
query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
|
||||
result = exp.sql_query(query)
|
||||
```
|
||||
"""
|
||||
@ -201,7 +209,7 @@ class Explorer:
|
||||
elif return_type == 'arrow':
|
||||
return rs.arrow()
|
||||
|
||||
def plot_sql_query(self, query, labels=True):
|
||||
def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:
|
||||
"""
|
||||
Plot the results of a SQL-Like query on the table.
|
||||
Args:
|
||||
@ -215,7 +223,7 @@ class Explorer:
|
||||
```python
|
||||
exp = Explorer()
|
||||
exp.create_embeddings_table()
|
||||
query = 'SELECT * FROM table WHERE labels LIKE "%person%"'
|
||||
query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
|
||||
result = exp.plot_sql_query(query)
|
||||
```
|
||||
"""
|
||||
@ -223,7 +231,11 @@ class Explorer:
|
||||
img = plot_similar_images(result, plot_labels=labels)
|
||||
return Image.fromarray(img)
|
||||
|
||||
def get_similar(self, img=None, idx=None, limit=25, return_type='pandas'):
|
||||
def get_similar(self,
|
||||
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
|
||||
idx: Union[int, List[int]] = None,
|
||||
limit: int = 25,
|
||||
return_type: str = 'pandas') -> Union[DataFrame, Any]: # pandas.dataframe or pyarrow.Table
|
||||
"""
|
||||
Query the table for similar images. Accepts a single image or a list of images.
|
||||
|
||||
@ -251,7 +263,11 @@ class Explorer:
|
||||
elif return_type == 'arrow':
|
||||
return similar
|
||||
|
||||
def plot_similar(self, img=None, idx=None, limit=25, labels=True):
|
||||
def plot_similar(self,
|
||||
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
|
||||
idx: Union[int, List[int]] = None,
|
||||
limit: int = 25,
|
||||
labels: bool = True) -> Image.Image:
|
||||
"""
|
||||
Plot the similar images. Accepts images or indexes.
|
||||
|
||||
@ -275,7 +291,7 @@ class Explorer:
|
||||
img = plot_similar_images(similar, plot_labels=labels)
|
||||
return Image.fromarray(img)
|
||||
|
||||
def similarity_index(self, max_dist=0.2, top_k=None, force=False):
|
||||
def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame:
|
||||
"""
|
||||
Calculate the similarity index of all the images in the table. Here, the index will contain the data points that
|
||||
are max_dist or closer to the image in the embedding space at a given index.
|
||||
@ -329,7 +345,7 @@ class Explorer:
|
||||
self.sim_index = sim_table
|
||||
return sim_table.to_pandas()
|
||||
|
||||
def plot_similarity_index(self, max_dist=0.2, top_k=None, force=False):
|
||||
def plot_similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Image:
|
||||
"""
|
||||
Plot the similarity index of all the images in the table. Here, the index will contain the data points that are
|
||||
max_dist or closer to the image in the embedding space at a given index.
|
||||
@ -341,13 +357,16 @@ class Explorer:
|
||||
force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
|
||||
|
||||
Returns:
|
||||
PIL Image containing the plot.
|
||||
PIL.PngImagePlugin.PngImageFile containing the plot.
|
||||
|
||||
Example:
|
||||
```python
|
||||
exp = Explorer()
|
||||
exp.create_embeddings_table()
|
||||
exp.plot_similarity_index()
|
||||
|
||||
similarity_idx_plot = exp.plot_similarity_index()
|
||||
similarity_idx_plot.show() # view image preview
|
||||
similarity_idx_plot.save('path/to/save/similarity_index_plot.png') # save contents to file
|
||||
```
|
||||
"""
|
||||
sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
|
||||
@ -368,9 +387,10 @@ class Explorer:
|
||||
buffer.seek(0)
|
||||
|
||||
# Use Pillow to open the image from the buffer
|
||||
return Image.open(buffer)
|
||||
return Image.fromarray(np.array(Image.open(buffer)))
|
||||
|
||||
def _check_imgs_or_idxs(self, img, idx):
|
||||
def _check_imgs_or_idxs(self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None],
|
||||
idx: Union[None, int, List[int]]) -> List[np.ndarray]:
|
||||
if img is None and idx is None:
|
||||
raise ValueError('Either img or idx must be provided.')
|
||||
if img is not None and idx is not None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user