mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00

Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com> Co-authored-by: uwer <uwe.rosebrock@gmail.com> Co-authored-by: Uwe Rosebrock <ro260@csiro.au> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing-q <1182102784@qq.com> Co-authored-by: Muhammad Rizwan Munawar <chr043416@gmail.com> Co-authored-by: AdamP <adamp87hun@gmail.com>
169 lines
6.8 KiB
Python
169 lines
6.8 KiB
Python
import getpass
|
|
from typing import List
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
from ultralytics.data.augment import LetterBox
|
|
from ultralytics.utils import LOGGER as logger
|
|
from ultralytics.utils import SETTINGS
|
|
from ultralytics.utils.checks import check_requirements
|
|
from ultralytics.utils.ops import xyxy2xywh
|
|
from ultralytics.utils.plotting import plot_images
|
|
|
|
|
|
def get_table_schema(vector_size):
    """
    Build the LanceDB pydantic model describing one row of the image table.

    The `lancedb` import is deferred to call time, so the package is only needed when a schema is requested.

    Args:
        vector_size (int): Length of the fixed-size `vector` column.

    Returns:
        (type): A `LanceModel` subclass named `Schema`.
    """
    from lancedb.pydantic import LanceModel, Vector

    class Schema(LanceModel):
        # Image file path (read back with cv2.imread when plotting query results).
        im_file: str
        # Class names, one per object, looked up from the dataset's `names` mapping.
        labels: List[str]
        # Integer class ids, one per object.
        cls: List[int]
        # One list of floats per object.
        bboxes: List[List[float]]
        # Nested mask values; rows without masks store the placeholder [[[]]].
        masks: List[List[List[int]]]
        # Nested keypoint values; rows without keypoints store the placeholder [[[]]].
        keypoints: List[List[List[float]]]
        # Fixed-size vector column of length `vector_size`.
        vector: Vector(vector_size)

    return Schema
|
|
|
|
|
|
def get_sim_index_schema():
    """
    Build the LanceDB pydantic model describing one row of the similarity-index table.

    The `lancedb` import is deferred to call time, so the package is only needed when a schema is requested.

    Returns:
        (type): A `LanceModel` subclass named `Schema`.
    """
    from lancedb.pydantic import LanceModel

    class Schema(LanceModel):
        # NOTE(review): presumably the row index of the image in the source table - confirm with the caller.
        idx: int
        # Image file path.
        im_file: str
        # NOTE(review): presumably the number of entries in `sim_im_files` - confirm with the caller.
        count: int
        # NOTE(review): presumably paths of the images considered similar to `im_file` - confirm.
        sim_im_files: List[str]

    return Schema
|
|
|
|
|
|
def sanitize_batch(batch, dataset_info):
    """
    Convert a batch's tensor fields to plain Python lists, ordered by class id.

    Args:
        batch (dict): Must hold 'cls' (supporting `.flatten().int().tolist()`) and 'bboxes' (supporting
            `.tolist()`). 'masks' and 'keypoints' are optional and are converted with `.tolist()` when present.
        dataset_info (dict): Must hold 'names', indexable by integer class id.

    Returns:
        (dict): The same `batch` object, modified in place, with 'cls', 'bboxes', 'labels', 'masks' and
            'keypoints' stored as Python lists.
    """
    class_ids = batch['cls'].flatten().int().tolist()
    batch['cls'] = class_ids
    boxes = batch['bboxes'].tolist()

    # Stable sort of positions by class id: objects sharing a class keep their original relative order.
    # Only positions present in both sequences are considered.
    paired_count = min(len(boxes), len(class_ids))
    order = sorted(range(paired_count), key=class_ids.__getitem__)

    batch['bboxes'] = [boxes[pos] for pos in order]
    batch['cls'] = [class_ids[pos] for pos in order]

    names = dataset_info['names']
    batch['labels'] = [names[class_id] for class_id in batch['cls']]

    # Missing optional fields are stored as an empty triple-nested list placeholder.
    for field in ('masks', 'keypoints'):
        batch[field] = batch[field].tolist() if field in batch else [[[]]]

    return batch
|
|
|
|
|
|
def plot_query_result(similar_set, plot_labels=True):
    """
    Plot images from the similar set.

    Args:
        similar_set (pandas.DataFrame | object with `to_pydict()`): Pandas or pyarrow object containing the
            similar data points. The columns read are 'im_file', 'bboxes', 'masks', 'keypoints' and 'cls';
            a missing column is treated as empty.
        plot_labels (bool): Whether to plot labels or not.

    Returns:
        The value returned by `plot_images` (called with `save=False, threaded=False`).

    Raises:
        ValueError: If the set contains no images.
        FileNotFoundError: If an image path cannot be read by OpenCV.
    """
    columns = similar_set.to_dict(orient='list') if isinstance(similar_set, pd.DataFrame) \
        else similar_set.to_pydict()

    def column(name):
        """Return the named column, or an empty list when it is absent."""
        values = columns.get(name)
        return [] if values is None else values

    # Placeholder stored for rows that carry no masks / keypoints.
    placeholder = [[[]]]

    images = column('im_file')
    if len(images) == 0:
        raise ValueError('similar_set contains no images to plot')

    # The original guarded this with `is not [[]]`, an identity test against a fresh list that is always
    # True. Per-image emptiness is handled inside the loop, so no column-level guard is needed.
    bboxes = column('bboxes')

    # As before, the first row decides whether the whole set carries masks / keypoints; the columns are
    # now length-checked first so a missing or empty column no longer raises.
    masks = column('masks')
    if len(masks) > 0 and masks[0] == placeholder:
        masks = []
    kpts = column('keypoints')
    if len(kpts) > 0 and kpts[0] == placeholder:
        kpts = []
    cls = column('cls')

    plot_size = 640
    imgs, batch_idx, plot_boxes, plot_masks, plot_kpts = [], [], [], [], []
    for i, imf in enumerate(images):
        im = cv2.imread(imf)
        if im is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f'Image not found or unreadable: {imf}')
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        h, w = im.shape[:2]
        # Scale factor applied to box and keypoint coordinates for the resized image.
        r = min(plot_size / h, plot_size / w)
        imgs.append(LetterBox(plot_size, center=False)(image=im).transpose(2, 0, 1))

        # Boxes of this image; empty when the column is shorter than the image list.
        image_boxes = bboxes[i] if i < len(bboxes) else []
        if plot_labels:
            if len(image_boxes) > 0:
                box = np.array(image_boxes, dtype=np.float32)
                box[:, [0, 2]] *= r
                box[:, [1, 3]] *= r
                plot_boxes.append(box)
            if len(masks) > i and len(masks[i]) > 0:
                mask = np.array(masks[i], dtype=np.uint8)[0]
                plot_masks.append(LetterBox(plot_size, center=False)(image=mask))
            if len(kpts) > i and kpts[i] is not None:
                kpt = np.array(kpts[i], dtype=np.float32)
                kpt[:, :, :2] *= r
                plot_kpts.append(kpt)
        # One entry per box, holding the index of the image the box belongs to.
        batch_idx.append(np.full(len(image_boxes), i, dtype=np.float64))

    imgs = np.stack(imgs, axis=0)
    masks = np.stack(plot_masks, axis=0) if len(plot_masks) > 0 else np.zeros(0, dtype=np.uint8)
    kpts = np.concatenate(plot_kpts, axis=0) if len(plot_kpts) > 0 else np.zeros((0, 51), dtype=np.float32)
    boxes = xyxy2xywh(np.concatenate(plot_boxes, axis=0)) if len(plot_boxes) > 0 else np.zeros(0, dtype=np.float32)
    batch_idx = np.concatenate(batch_idx, axis=0)
    cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0) if len(cls) > 0 \
        else np.zeros(0, dtype=np.int32)

    return plot_images(imgs,
                       batch_idx,
                       cls,
                       bboxes=boxes,
                       masks=masks,
                       kpts=kpts,
                       max_subplots=len(images),
                       save=False,
                       threaded=False)
|
|
|
|
|
|
def prompt_sql_query(query):
    """
    Ask an OpenAI chat model to translate a natural-language request into a SQL query over the table.

    If no OpenAI API key is stored in SETTINGS, the user is prompted for one via `getpass` and the value is
    written back to SETTINGS before the request is made.

    Args:
        query (str): The user's request in natural language.

    Returns:
        The message content of the first choice returned by the chat completion.
    """
    check_requirements('openai>=1.6.1')
    from openai import OpenAI

    if not SETTINGS['openai_api_key']:
        logger.warning('OpenAI API key not found in settings. Please enter your API key below.')
        SETTINGS.update({'openai_api_key': getpass.getpass('OpenAI API key: ')})
    client = OpenAI(api_key=SETTINGS['openai_api_key'])

    # System prompt: fixes the output format and describes the table schema to the model.
    system_prompt = '''
                You are a helpful data scientist proficient in SQL. You need to output exactly one SQL query based on
                the following schema and a user request. You only need to output the format with fixed selection
                statement that selects everything from "'table'", like `SELECT * from 'table'`

                Schema:
                im_file: string not null
                labels: list<item: string> not null
                child 0, item: string
                cls: list<item: int64> not null
                child 0, item: int64
                bboxes: list<item: list<item: double>> not null
                child 0, item: list<item: double>
                    child 0, item: double
                masks: list<item: list<item: list<item: int64>>> not null
                child 0, item: list<item: list<item: int64>>
                    child 0, item: list<item: int64>
                        child 0, item: int64
                keypoints: list<item: list<item: list<item: double>>> not null
                child 0, item: list<item: list<item: double>>
                    child 0, item: list<item: double>
                        child 0, item: double
                vector: fixed_size_list<item: float>[256] not null
                child 0, item: float

                Some details about the schema:
                - the "labels" column contains the string values like 'person' and 'dog' for the respective objects
                    in each image
                - the "cls" column contains the integer values on these classes that map them the labels

                Example of a correct query:
                request - Get all data points that contain 2 or more people and at least one dog
                correct query-
                SELECT * FROM 'table' WHERE ARRAY_LENGTH(cls) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1;
             '''

    conversation = [
        {'role': 'system', 'content': system_prompt},
        {'role': 'user', 'content': f'{query}'},
    ]

    completion = client.chat.completions.create(model='gpt-3.5-turbo', messages=conversation)
    return completion.choices[0].message.content
|