Glenn Jocher 783033fa6b
ultralytics 8.0.238 Explorer Ask AI feature and fixes (#7408)
Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
Co-authored-by: uwer <uwe.rosebrock@gmail.com>
Co-authored-by: Uwe Rosebrock <ro260@csiro.au>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing-q <1182102784@qq.com>
Co-authored-by: Muhammad Rizwan Munawar <chr043416@gmail.com>
Co-authored-by: AdamP <adamp87hun@gmail.com>
2024-01-08 23:36:29 +01:00

169 lines
6.8 KiB
Python

import getpass
from typing import List
import cv2
import numpy as np
import pandas as pd
from ultralytics.data.augment import LetterBox
from ultralytics.utils import LOGGER as logger
from ultralytics.utils import SETTINGS
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.ops import xyxy2xywh
from ultralytics.utils.plotting import plot_images
def get_table_schema(vector_size):
    """
    Build the LanceDB row schema for a table of images with their annotations.

    Args:
        vector_size: Dimension handed to lancedb's `Vector` to declare the `vector` field.

    Returns:
        A `LanceModel` subclass named `Schema` with the fields `im_file`, `labels`, `cls`, `bboxes`,
        `masks`, `keypoints` and `vector`.
    """
    # Imported lazily so lancedb is only required when a schema is actually requested.
    from lancedb.pydantic import LanceModel, Vector

    class Schema(LanceModel):
        # One row per image.
        im_file: str
        labels: List[str]
        cls: List[int]
        bboxes: List[List[float]]
        masks: List[List[List[int]]]
        keypoints: List[List[List[float]]]
        vector: Vector(vector_size)

    return Schema
def get_sim_index_schema():
    """
    Build the LanceDB row schema for a similarity-index table.

    Returns:
        A `LanceModel` subclass named `Schema` with the fields `idx`, `im_file`, `count` and
        `sim_im_files`.
    """
    # Imported lazily so lancedb is only required when a schema is actually requested.
    from lancedb.pydantic import LanceModel

    class Schema(LanceModel):
        idx: int
        im_file: str
        count: int
        sim_im_files: List[str]

    return Schema
def sanitize_batch(batch, dataset_info):
    """
    Convert a batch's annotation entries to plain Python lists, in place.

    Boxes and class ids are paired and stably sorted by class id, and a `labels` entry is added by
    looking each class id up in `dataset_info['names']`. `masks` and `keypoints` are converted with
    `.tolist()` when present and set to the placeholder `[[[]]]` otherwise; they are NOT reordered
    together with the boxes.

    Args:
        batch (dict): Must hold `cls` and `bboxes` objects supporting `.tolist()` (and, for `cls`,
            `.flatten().int()`); may hold `masks` and `keypoints`.
        dataset_info (dict): Must hold a `names` mapping indexable by class id.

    Returns:
        (dict): The same `batch` object, mutated.
    """
    class_ids = batch['cls'].flatten().int().tolist()

    # list.sort is stable, so boxes sharing a class keep their relative order.
    paired = list(zip(batch['bboxes'].tolist(), class_ids))
    paired.sort(key=lambda pair: pair[1])

    batch['bboxes'] = [pair[0] for pair in paired]
    batch['cls'] = [pair[1] for pair in paired]
    batch['labels'] = [dataset_info['names'][class_id] for class_id in batch['cls']]

    for optional_key in ('masks', 'keypoints'):
        batch[optional_key] = batch[optional_key].tolist() if optional_key in batch else [[[]]]
    return batch
def plot_query_result(similar_set, plot_labels=True):
    """
    Plot images from the similar set.

    Args:
        similar_set (list): Pyarrow or pandas object containing the similar data points
        plot_labels (bool): Whether to plot labels or not

    Returns:
        Whatever `plot_images` returns for the assembled batch (it is called with `save=False`).

    Raises:
        ValueError: If the similar set lists no images.
        FileNotFoundError: If an entry of `im_file` cannot be read by OpenCV.
    """
    if isinstance(similar_set, pd.DataFrame):
        similar_set = similar_set.to_dict(orient='list')
    else:
        similar_set = similar_set.to_pydict()

    # Placeholder that sanitize_batch stores when a sample has no masks / keypoints.
    empty_annotation = [[[]]]

    images = similar_set.get('im_file') or []
    if not images:
        raise ValueError('similar_set contains no images to plot.')

    # A missing column is treated like an empty one instead of failing on None[0] / bboxes[i].
    bboxes = similar_set.get('bboxes') or []
    masks = similar_set.get('masks') or []
    kpts = similar_set.get('keypoints') or []
    cls = similar_set.get('cls') or []

    # Only the first row is inspected: if it holds the placeholder the whole column is ignored.
    if masks and masks[0] == empty_annotation:
        masks = []
    if kpts and kpts[0] == empty_annotation:
        kpts = []

    plot_size = 640
    # Built once and reused; assumes LetterBox.__call__ keeps no per-call state.
    letterbox = LetterBox(plot_size, center=False)
    imgs, batch_idx, plot_boxes, plot_masks, plot_kpts = [], [], [], [], []
    for i, imf in enumerate(images):
        im = cv2.imread(imf)
        if im is None:  # cv2.imread signals an unreadable file by returning None instead of raising
            raise FileNotFoundError(f'Image not found or unreadable: {imf}')
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        h, w = im.shape[:2]
        r = min(plot_size / h, plot_size / w)  # scale applied to box / keypoint coordinates below
        imgs.append(letterbox(image=im).transpose(2, 0, 1))

        row_boxes = bboxes[i] if i < len(bboxes) else []
        if plot_labels:
            if len(row_boxes) > 0:
                box = np.array(row_boxes, dtype=np.float32)
                box[:, :4] *= r
                plot_boxes.append(box)
            if len(masks) > i and len(masks[i]) > 0:
                mask = np.array(masks[i], dtype=np.uint8)[0]
                plot_masks.append(letterbox(image=mask))
            if len(kpts) > i and kpts[i] is not None:
                kpt = np.array(kpts[i], dtype=np.float32)
                kpt[:, :, :2] *= r
                plot_kpts.append(kpt)
        # One image index per box of this row, built even when labels are not plotted.
        batch_idx.append(np.full(len(row_boxes), i, dtype=np.float64))

    imgs = np.stack(imgs, axis=0)
    masks = np.stack(plot_masks, axis=0) if plot_masks else np.zeros(0, dtype=np.uint8)
    # NOTE(review): the (0, 51) placeholder width is hard-coded; presumably 17 keypoints x 3 values - confirm.
    kpts = np.concatenate(plot_kpts, axis=0) if plot_kpts else np.zeros((0, 51), dtype=np.float32)
    boxes = xyxy2xywh(np.concatenate(plot_boxes, axis=0)) if plot_boxes else np.zeros(0, dtype=np.float32)
    batch_idx = np.concatenate(batch_idx, axis=0)
    if cls:
        cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0)
    else:
        cls = np.zeros(0, dtype=np.int32)

    return plot_images(imgs,
                       batch_idx,
                       cls,
                       bboxes=boxes,
                       masks=masks,
                       kpts=kpts,
                       max_subplots=len(images),
                       save=False,
                       threaded=False)
def prompt_sql_query(query):
    """
    Ask an OpenAI chat model to turn a user request into a SQL query over the table schema.

    If `SETTINGS['openai_api_key']` is empty, the key is requested on the terminal via `getpass` and
    written back into `SETTINGS` before the client is created.

    Args:
        query: The user request; it is formatted into the user message as-is.

    Returns:
        The `content` of the first choice's message in the chat completion response.
    """
    check_requirements('openai>=1.6.1')
    from openai import OpenAI

    if not SETTINGS['openai_api_key']:
        logger.warning('OpenAI API key not found in settings. Please enter your API key below.')
        entered_key = getpass.getpass('OpenAI API key: ')
        SETTINGS.update({'openai_api_key': entered_key})
    client = OpenAI(api_key=SETTINGS['openai_api_key'])

    # The prompt lines start at column 0 on purpose: any leading whitespace would become part of the string.
    system_prompt = '''
You are a helpful data scientist proficient in SQL. You need to output exactly one SQL query based on
the following schema and a user request. You only need to output the format with fixed selection
statement that selects everything from "'table'", like `SELECT * from 'table'`
Schema:
im_file: string not null
labels: list<item: string> not null
child 0, item: string
cls: list<item: int64> not null
child 0, item: int64
bboxes: list<item: list<item: double>> not null
child 0, item: list<item: double>
child 0, item: double
masks: list<item: list<item: list<item: int64>>> not null
child 0, item: list<item: list<item: int64>>
child 0, item: list<item: int64>
child 0, item: int64
keypoints: list<item: list<item: list<item: double>>> not null
child 0, item: list<item: list<item: double>>
child 0, item: list<item: double>
child 0, item: double
vector: fixed_size_list<item: float>[256] not null
child 0, item: float
Some details about the schema:
- the "labels" column contains the string values like 'person' and 'dog' for the respective objects
in each image
- the "cls" column contains the integer values on these classes that map them the labels
Example of a correct query:
request - Get all data points that contain 2 or more people and at least one dog
correct query-
SELECT * FROM 'table' WHERE ARRAY_LENGTH(cls) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1;
'''
    system_message = {'role': 'system', 'content': system_prompt}
    user_message = {'role': 'user', 'content': f'{query}'}

    completion = client.chat.completions.create(model='gpt-3.5-turbo', messages=[system_message, user_message])
    return completion.choices[0].message.content