mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
ultralytics 8.0.52
reduced TAL CUDA usage and AMP check fix (#1333)
Co-authored-by: CNH5 <74132034+CNH5@users.noreply.github.com> Co-authored-by: Huijae Lee <46982469+ZeroAct@users.noreply.github.com> Co-authored-by: Lorenzo Mammana <lorenzom96@hotmail.it> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Hardik Dava <39372750+hardikdava@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
parent
790f9c067c
commit
177a68b39f
10
.github/workflows/publish.yml
vendored
10
.github/workflows/publish.yml
vendored
@ -4,6 +4,8 @@
|
||||
name: Publish to PyPI and Deploy Docs
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
pypi:
|
||||
@ -12,8 +14,6 @@ on:
|
||||
docs:
|
||||
type: boolean
|
||||
description: Deploy Docs
|
||||
push:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
publish:
|
||||
@ -35,9 +35,9 @@ jobs:
|
||||
- name: Check PyPI version
|
||||
shell: python
|
||||
run: |
|
||||
import os
|
||||
import pkg_resources as pkg
|
||||
import ultralytics
|
||||
import os
|
||||
from ultralytics.yolo.utils.checks import check_latest_pypi_version
|
||||
|
||||
v_local = pkg.parse_version(ultralytics.__version__).release
|
||||
@ -45,7 +45,7 @@ jobs:
|
||||
print(f'Local version is {v_local}')
|
||||
print(f'PyPI version is {v_pypi}')
|
||||
d = [a - b for a, b in zip(v_local, v_pypi)] # diff
|
||||
increment = (d[0] == d[1] == 0) and d[2] == 1 # only patch increment by 1
|
||||
increment = (d[0] == d[1] == 0) and d[2] == 1 # only publish if patch version increments by 1
|
||||
os.system(f'echo "increment={increment}" >> $GITHUB_OUTPUT')
|
||||
if increment:
|
||||
print('Local version is higher than PyPI version. Publishing new version to PyPI ✅.')
|
||||
@ -64,4 +64,4 @@ jobs:
|
||||
run: |
|
||||
mkdocs gh-deploy || true
|
||||
git checkout gh-pages
|
||||
git push https://github.com/ultralytics/docs gh-pages --force
|
||||
git push https://${{ secrets.PERSONAL_ACCESS_TOKEN }}@github.com/ultralytics/docs gh-pages --force
|
||||
|
@ -110,7 +110,6 @@ task.
|
||||
| mask_ratio | 4 | mask downsample ratio (segment train only) |
|
||||
| dropout | 0.0 | use dropout regularization (classify train only) |
|
||||
| val | True | validate/test during training |
|
||||
| min_memory | False | minimize memory footprint loss function, choices=[False, True, <roll_out_thr>] |
|
||||
|
||||
### Prediction
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
__version__ = '8.0.51'
|
||||
__version__ = '8.0.52'
|
||||
|
||||
from ultralytics.yolo.engine.model import YOLO
|
||||
from ultralytics.yolo.utils.checks import check_yolo as checks
|
||||
|
@ -5,13 +5,9 @@ import requests
|
||||
from ultralytics.hub.auth import Auth
|
||||
from ultralytics.hub.session import HUBTrainingSession
|
||||
from ultralytics.hub.utils import PREFIX, split_key
|
||||
from ultralytics.yolo.engine.exporter import EXPORT_FORMATS_LIST
|
||||
from ultralytics.yolo.engine.model import YOLO
|
||||
from ultralytics.yolo.utils import LOGGER, emojis
|
||||
|
||||
# Define all export formats
|
||||
EXPORT_FORMATS_HUB = EXPORT_FORMATS_LIST + ['ultralytics_tflite', 'ultralytics_coreml']
|
||||
|
||||
|
||||
def start(key=''):
|
||||
"""
|
||||
@ -63,9 +59,15 @@ def reset_model(key=''):
|
||||
LOGGER.warning(f'{PREFIX}Model reset failure {r.status_code} {r.reason}')
|
||||
|
||||
|
||||
def export_fmts_hub():
|
||||
# Returns a list of HUB-supported export formats
|
||||
from ultralytics.yolo.engine.exporter import export_formats
|
||||
return list(export_formats()['Argument'][1:]) + ['ultralytics_tflite', 'ultralytics_coreml']
|
||||
|
||||
|
||||
def export_model(key='', format='torchscript'):
|
||||
# Export a model to all formats
|
||||
assert format in EXPORT_FORMATS_HUB, f"Unsupported export format '{format}', valid formats are {EXPORT_FORMATS_HUB}"
|
||||
assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
|
||||
api_key, model_id = split_key(key)
|
||||
r = requests.post('https://api.ultralytics.com/export',
|
||||
json={
|
||||
@ -78,7 +80,7 @@ def export_model(key='', format='torchscript'):
|
||||
|
||||
def get_export(key='', format='torchscript'):
|
||||
# Get an exported model dictionary with download URL
|
||||
assert format in EXPORT_FORMATS_HUB, f"Unsupported export format '{format}', valid formats are {EXPORT_FORMATS_HUB}"
|
||||
assert format in export_fmts_hub, f"Unsupported export format '{format}', valid formats are {export_fmts_hub}"
|
||||
api_key, model_id = split_key(key)
|
||||
r = requests.post('https://api.ultralytics.com/get-export',
|
||||
json={
|
||||
|
@ -184,7 +184,7 @@ class AutoBackend(nn.Module):
|
||||
LOGGER.info(f'Loading {w} for CoreML inference...')
|
||||
import coremltools as ct
|
||||
model = ct.models.MLModel(w)
|
||||
metadata = model.user_defined_metadata
|
||||
metadata = dict(model.user_defined_metadata)
|
||||
elif saved_model: # TF SavedModel
|
||||
LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
|
||||
import tensorflow as tf
|
||||
@ -256,10 +256,10 @@ class AutoBackend(nn.Module):
|
||||
nhwc = model.runtime.startswith("tensorflow")
|
||||
'''
|
||||
else:
|
||||
from ultralytics.yolo.engine.exporter import EXPORT_FORMATS_TABLE
|
||||
from ultralytics.yolo.engine.exporter import export_formats
|
||||
raise TypeError(f"model='{w}' is not a supported model format. "
|
||||
'See https://docs.ultralytics.com/tasks/detection/#export for help.'
|
||||
f'\n\n{EXPORT_FORMATS_TABLE}')
|
||||
f'\n\n{export_formats()}')
|
||||
|
||||
# Load external metadata YAML
|
||||
if isinstance(metadata, (str, Path)) and Path(metadata).exists():
|
||||
|
@ -8,7 +8,6 @@ from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import requests
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -204,12 +203,13 @@ class Detections:
|
||||
|
||||
def pandas(self):
|
||||
# return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
|
||||
import pandas
|
||||
new = copy(self) # return copy
|
||||
ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
|
||||
cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns
|
||||
for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
|
||||
a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update
|
||||
setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
|
||||
setattr(new, k, [pandas.DataFrame(x, columns=c) for x in a])
|
||||
return new
|
||||
|
||||
def tolist(self):
|
||||
|
@ -122,7 +122,7 @@ class BaseModel(nn.Module):
|
||||
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
|
||||
return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
|
||||
|
||||
def info(self, verbose=False, imgsz=640):
|
||||
def info(self, verbose=True, imgsz=640):
|
||||
"""
|
||||
Prints model information
|
||||
|
||||
|
@ -36,18 +36,26 @@ def _indices_to_matches(cost_matrix, indices, thresh):
|
||||
return matches, unmatched_a, unmatched_b
|
||||
|
||||
|
||||
def linear_assignment(cost_matrix, thresh):
|
||||
def linear_assignment(cost_matrix, thresh, use_lap=True):
|
||||
# Linear assignment implementations with scipy and lap.lapjv
|
||||
if cost_matrix.size == 0:
|
||||
return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
|
||||
matches, unmatched_a, unmatched_b = [], [], []
|
||||
|
||||
# TODO: investigate scipy.optimize.linear_sum_assignment() for lap.lapjv()
|
||||
cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
|
||||
|
||||
matches.extend([ix, mx] for ix, mx in enumerate(x) if mx >= 0)
|
||||
if use_lap:
|
||||
_, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
|
||||
matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0]
|
||||
unmatched_a = np.where(x < 0)[0]
|
||||
unmatched_b = np.where(y < 0)[0]
|
||||
matches = np.asarray(matches)
|
||||
else:
|
||||
# Scipy linear sum assignment is NOT working correctly, DO NOT USE
|
||||
y, x = scipy.optimize.linear_sum_assignment(cost_matrix) # row y, col x
|
||||
matches = np.asarray([[i, x] for i, x in enumerate(x) if cost_matrix[i, x] <= thresh])
|
||||
unmatched = np.ones(cost_matrix.shape)
|
||||
for i, xi in matches:
|
||||
unmatched[i, xi] = 0.0
|
||||
unmatched_a = np.where(unmatched.all(1))[0]
|
||||
unmatched_b = np.where(unmatched.all(0))[0]
|
||||
|
||||
return matches, unmatched_a, unmatched_b
|
||||
|
||||
|
||||
|
@ -17,7 +17,7 @@ cache: False # True/ram, disk or False. Use cache for data loading
|
||||
device: # device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu
|
||||
workers: 8 # number of worker threads for data loading (per RANK if DDP)
|
||||
project: # project name
|
||||
name: # experiment name
|
||||
name: # experiment name, results saved to 'project/name' directory
|
||||
exist_ok: False # whether to overwrite existing experiment
|
||||
pretrained: False # whether to use a pretrained model
|
||||
optimizer: SGD # optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
|
||||
@ -30,7 +30,6 @@ rect: False # support rectangular training if mode='train', support rectangular
|
||||
cos_lr: False # use cosine learning rate scheduler
|
||||
close_mosaic: 10 # disable mosaic augmentation for final 10 epochs
|
||||
resume: False # resume training from last checkpoint
|
||||
min_memory: False # minimize memory footprint loss function, choices=[False, True, <roll_out_thr>]
|
||||
# Segmentation
|
||||
overlap_mask: True # masks should overlap during training (segment train only)
|
||||
mask_ratio: 4 # mask downsample ratio (segment train only)
|
||||
|
@ -23,8 +23,8 @@ from ultralytics.yolo.utils.downloads import download, safe_download, unzip_file
|
||||
from ultralytics.yolo.utils.ops import segments2boxes
|
||||
|
||||
HELP_URL = 'See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
|
||||
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # include image suffixes
|
||||
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
|
||||
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # image suffixes
|
||||
VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv', 'webm' # video suffixes
|
||||
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||
RANK = int(os.getenv('RANK', -1))
|
||||
PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders
|
||||
|
@ -57,16 +57,13 @@ from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
from ultralytics.nn.autobackend import check_class_names
|
||||
from ultralytics.nn.modules import C2f, Detect, Segment
|
||||
from ultralytics.nn.tasks import DetectionModel, SegmentationModel
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
|
||||
from ultralytics.yolo.data.utils import IMAGENET_MEAN, IMAGENET_STD, check_det_dataset
|
||||
from ultralytics.yolo.data.utils import IMAGENET_MEAN, IMAGENET_STD
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, LINUX, LOGGER, MACOS, __version__, callbacks, colorstr,
|
||||
get_default_args, yaml_save)
|
||||
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version
|
||||
@ -79,6 +76,7 @@ ARM64 = platform.machine() in ('arm64', 'aarch64')
|
||||
|
||||
def export_formats():
|
||||
# YOLOv8 export formats
|
||||
import pandas
|
||||
x = [
|
||||
['PyTorch', '-', '.pt', True, True],
|
||||
['TorchScript', 'torchscript', '.torchscript', True, True],
|
||||
@ -92,11 +90,7 @@ def export_formats():
|
||||
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', True, False],
|
||||
['TensorFlow.js', 'tfjs', '_web_model', True, False],
|
||||
['PaddlePaddle', 'paddle', '_paddle_model', True, True], ]
|
||||
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
|
||||
|
||||
|
||||
EXPORT_FORMATS_LIST = list(export_formats()['Argument'][1:])
|
||||
EXPORT_FORMATS_TABLE = str(export_formats())
|
||||
return pandas.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
|
||||
|
||||
|
||||
def gd_outputs(gd):
|
||||
@ -537,7 +531,7 @@ class Exporter:
|
||||
# Export to TF
|
||||
int8 = '-oiqt -qt per-tensor' if self.args.int8 else ''
|
||||
cmd = f'onnx2tf -i {f_onnx} -o {f} -nuo --non_verbose {int8}'
|
||||
LOGGER.info(f'\n{prefix} running {cmd}')
|
||||
LOGGER.info(f"\n{prefix} running '{cmd}'")
|
||||
subprocess.run(cmd, shell=True)
|
||||
yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml
|
||||
|
||||
@ -574,47 +568,47 @@ class Exporter:
|
||||
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
|
||||
saved_model = Path(str(self.file).replace(self.file.suffix, '_saved_model'))
|
||||
if self.args.int8:
|
||||
f = saved_model / (self.file.stem + '_integer_quant.tflite') # fp32 in/out
|
||||
f = saved_model / f'{self.file.stem}_integer_quant.tflite' # fp32 in/out
|
||||
elif self.args.half:
|
||||
f = saved_model / (self.file.stem + '_float16.tflite')
|
||||
f = saved_model / f'{self.file.stem}_float16.tflite'
|
||||
else:
|
||||
f = saved_model / (self.file.stem + '_float32.tflite')
|
||||
return str(f), None # noqa
|
||||
f = saved_model / f'{self.file.stem}_float32.tflite'
|
||||
return str(f), None
|
||||
|
||||
# OLD VERSION BELOW ---------------------------------------------------------------
|
||||
batch_size, ch, *imgsz = list(self.im.shape) # BCHW
|
||||
f = str(self.file).replace(self.file.suffix, '-fp16.tflite')
|
||||
|
||||
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
|
||||
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
|
||||
converter.target_spec.supported_types = [tf.float16]
|
||||
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
||||
if self.args.int8:
|
||||
|
||||
def representative_dataset_gen(dataset, n_images=100):
|
||||
# Dataset generator for use with converter.representative_dataset, returns a generator of np arrays
|
||||
for n, (path, img, im0s, vid_cap, string) in enumerate(dataset):
|
||||
im = np.transpose(img, [1, 2, 0])
|
||||
im = np.expand_dims(im, axis=0).astype(np.float32)
|
||||
im /= 255
|
||||
yield [im]
|
||||
if n >= n_images:
|
||||
break
|
||||
|
||||
dataset = LoadImages(check_det_dataset(self.args.data)['train'], imgsz=imgsz, auto=False)
|
||||
converter.representative_dataset = lambda: representative_dataset_gen(dataset, n_images=100)
|
||||
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
|
||||
converter.target_spec.supported_types = []
|
||||
converter.inference_input_type = tf.uint8 # or tf.int8
|
||||
converter.inference_output_type = tf.uint8 # or tf.int8
|
||||
converter.experimental_new_quantizer = True
|
||||
f = str(self.file).replace(self.file.suffix, '-int8.tflite')
|
||||
if nms or agnostic_nms:
|
||||
converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
|
||||
|
||||
tflite_model = converter.convert()
|
||||
open(f, 'wb').write(tflite_model)
|
||||
return f, None
|
||||
# # OLD TFLITE EXPORT CODE BELOW -------------------------------------------------------------------------------
|
||||
# batch_size, ch, *imgsz = list(self.im.shape) # BCHW
|
||||
# f = str(self.file).replace(self.file.suffix, '-fp16.tflite')
|
||||
#
|
||||
# converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
|
||||
# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
|
||||
# converter.target_spec.supported_types = [tf.float16]
|
||||
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
||||
# if self.args.int8:
|
||||
#
|
||||
# def representative_dataset_gen(dataset, n_images=100):
|
||||
# # Dataset generator for use with converter.representative_dataset, returns a generator of np arrays
|
||||
# for n, (path, img, im0s, vid_cap, string) in enumerate(dataset):
|
||||
# im = np.transpose(img, [1, 2, 0])
|
||||
# im = np.expand_dims(im, axis=0).astype(np.float32)
|
||||
# im /= 255
|
||||
# yield [im]
|
||||
# if n >= n_images:
|
||||
# break
|
||||
#
|
||||
# dataset = LoadImages(check_det_dataset(self.args.data)['train'], imgsz=imgsz, auto=False)
|
||||
# converter.representative_dataset = lambda: representative_dataset_gen(dataset, n_images=100)
|
||||
# converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
|
||||
# converter.target_spec.supported_types = []
|
||||
# converter.inference_input_type = tf.uint8 # or tf.int8
|
||||
# converter.inference_output_type = tf.uint8 # or tf.int8
|
||||
# converter.experimental_new_quantizer = True
|
||||
# f = str(self.file).replace(self.file.suffix, '-int8.tflite')
|
||||
# if nms or agnostic_nms:
|
||||
# converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
|
||||
#
|
||||
# tflite_model = converter.convert()
|
||||
# open(f, 'wb').write(tflite_model)
|
||||
# return f, None
|
||||
|
||||
@try_export
|
||||
def _export_edgetpu(self, tflite_model='', prefix=colorstr('Edge TPU:')):
|
||||
@ -638,6 +632,7 @@ class Exporter:
|
||||
f = str(tflite_model).replace('.tflite', '_edgetpu.tflite') # Edge TPU model
|
||||
|
||||
cmd = f'edgetpu_compiler -s -d -k 10 --out_dir {Path(f).parent} {tflite_model}'
|
||||
LOGGER.info(f"{prefix} running '{cmd}'")
|
||||
subprocess.run(cmd.split(), check=True)
|
||||
self._add_tflite_metadata(f)
|
||||
return f, None
|
||||
|
@ -48,7 +48,7 @@ class Results:
|
||||
self.probs = probs if probs is not None else None
|
||||
self.names = names
|
||||
self.path = path
|
||||
self._keys = (k for k in ('boxes', 'masks', 'probs') if getattr(self, k) is not None)
|
||||
self._keys = [k for k in ('boxes', 'masks', 'probs') if getattr(self, k) is not None]
|
||||
|
||||
def pandas(self):
|
||||
pass
|
||||
@ -122,8 +122,7 @@ class Results:
|
||||
Returns:
|
||||
(None) or (PIL.Image): If `pil` is True, a PIL Image is returned. Otherwise, nothing is returned.
|
||||
"""
|
||||
img = deepcopy(self.orig_img)
|
||||
annotator = Annotator(img, line_width, font_size, font, pil, example)
|
||||
annotator = Annotator(deepcopy(self.orig_img), line_width, font_size, font, pil, example)
|
||||
boxes = self.boxes
|
||||
masks = self.masks
|
||||
logits = self.probs
|
||||
@ -136,7 +135,7 @@ class Results:
|
||||
annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
|
||||
|
||||
if masks is not None:
|
||||
im = torch.as_tensor(img, dtype=torch.float16, device=masks.data.device).permute(2, 0, 1).flip(0)
|
||||
im = torch.as_tensor(annotator.im, dtype=torch.float16, device=masks.data.device).permute(2, 0, 1).flip(0)
|
||||
im = F.resize(im.contiguous(), masks.data.shape[1:]) / 255
|
||||
annotator.masks(masks.data, colors=[colors(x, True) for x in boxes.cls], im_gpu=im)
|
||||
|
||||
@ -146,7 +145,7 @@ class Results:
|
||||
text = f"{', '.join(f'{names[j] if names else j} {logits[j]:.2f}' for j in top5i)}, "
|
||||
annotator.text((32, 32), text, txt_color=(255, 255, 255)) # TODO: allow setting colors
|
||||
|
||||
return img
|
||||
return np.asarray(annotator.im) if annotator.pil else annotator.im
|
||||
|
||||
|
||||
class Boxes:
|
||||
|
@ -197,12 +197,15 @@ class BaseTrainer:
|
||||
"""
|
||||
Builds dataloaders and optimizer on correct rank process.
|
||||
"""
|
||||
# model
|
||||
# Model
|
||||
self.run_callbacks('on_pretrain_routine_start')
|
||||
ckpt = self.setup_model()
|
||||
self.model = self.model.to(self.device)
|
||||
self.set_model_attributes()
|
||||
# Check AMP
|
||||
callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as they are reset by check_amp()
|
||||
self.amp = check_amp(self.model)
|
||||
callbacks.default_callbacks = callbacks_backup # restore callbacks
|
||||
self.scaler = amp.GradScaler(enabled=self.amp)
|
||||
if world_size > 1:
|
||||
self.model = DDP(self.model, device_ids=[rank])
|
||||
@ -610,7 +613,7 @@ def check_amp(model):
|
||||
a = m(im, device=device, verbose=False)[0].boxes.boxes # FP32 inference
|
||||
with torch.cuda.amp.autocast(True):
|
||||
b = m(im, device=device, verbose=False)[0].boxes.boxes # AMP inference
|
||||
return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.1) # close to 10% absolute tolerance
|
||||
return a.shape == b.shape and torch.allclose(a, b.float(), rtol=0.1) # close to 10% absolute tolerance
|
||||
|
||||
f = ROOT / 'assets/bus.jpg' # image to check
|
||||
im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if ONLINE else np.ones((640, 640, 3))
|
||||
|
@ -17,7 +17,6 @@ from typing import Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
@ -95,8 +94,6 @@ HELP_MSG = \
|
||||
# Settings
|
||||
torch.set_printoptions(linewidth=320, precision=5, profile='long')
|
||||
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
|
||||
pd.options.display.max_columns = 10
|
||||
pd.options.display.width = 120
|
||||
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
|
||||
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
|
||||
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # for deterministic training
|
||||
|
@ -26,8 +26,6 @@ import platform
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from ultralytics import YOLO
|
||||
from ultralytics.yolo.engine.exporter import export_formats
|
||||
from ultralytics.yolo.utils import LINUX, LOGGER, ROOT, SETTINGS
|
||||
@ -38,6 +36,9 @@ from ultralytics.yolo.utils.torch_utils import select_device
|
||||
|
||||
|
||||
def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', imgsz=160, half=False, device='cpu', hard_fail=False):
|
||||
import pandas as pd
|
||||
pd.options.display.max_columns = 10
|
||||
pd.options.display.width = 120
|
||||
device = select_device(device, verbose=False)
|
||||
if isinstance(model, (str, Path)):
|
||||
model = YOLO(model)
|
||||
|
@ -8,7 +8,6 @@ import cv2
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from PIL import __version__ as pil_version
|
||||
@ -160,6 +159,7 @@ class Annotator:
|
||||
|
||||
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
|
||||
def plot_labels(boxes, cls, names=(), save_dir=Path('')):
|
||||
import pandas as pd
|
||||
import seaborn as sn
|
||||
|
||||
# plot dataset labels
|
||||
@ -275,7 +275,7 @@ def plot_images(images,
|
||||
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
|
||||
annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
|
||||
if paths:
|
||||
annotator.text((x + 5, y + 5 + h), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
|
||||
annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
|
||||
if len(cls) > 0:
|
||||
idx = batch_idx == i
|
||||
|
||||
@ -330,6 +330,7 @@ def plot_images(images,
|
||||
|
||||
def plot_results(file='path/to/results.csv', dir='', segment=False):
|
||||
# Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
|
||||
import pandas as pd
|
||||
save_dir = Path(file).parent if file else Path(dir)
|
||||
if segment:
|
||||
fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
|
||||
|
@ -10,7 +10,7 @@ from .metrics import bbox_iou
|
||||
TORCH_1_10 = check_version(torch.__version__, '1.10.0')
|
||||
|
||||
|
||||
def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9, roll_out=False):
|
||||
def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
|
||||
"""select the positive anchor center in gt
|
||||
|
||||
Args:
|
||||
@ -21,14 +21,6 @@ def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9, roll_out=False):
|
||||
"""
|
||||
n_anchors = xy_centers.shape[0]
|
||||
bs, n_boxes, _ = gt_bboxes.shape
|
||||
if roll_out:
|
||||
bbox_deltas = torch.empty((bs, n_boxes, n_anchors), device=gt_bboxes.device)
|
||||
for b in range(bs):
|
||||
lt, rb = gt_bboxes[b].view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom
|
||||
bbox_deltas[b] = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]),
|
||||
dim=2).view(n_boxes, n_anchors, -1).amin(2).gt_(eps)
|
||||
return bbox_deltas
|
||||
else:
|
||||
lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom
|
||||
bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
|
||||
# return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype)
|
||||
@ -63,7 +55,7 @@ def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
|
||||
|
||||
class TaskAlignedAssigner(nn.Module):
|
||||
|
||||
def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9, roll_out_thr=0):
|
||||
def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
|
||||
super().__init__()
|
||||
self.topk = topk
|
||||
self.num_classes = num_classes
|
||||
@ -71,7 +63,6 @@ class TaskAlignedAssigner(nn.Module):
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
self.eps = eps
|
||||
self.roll_out_thr = roll_out_thr
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
|
||||
@ -93,7 +84,6 @@ class TaskAlignedAssigner(nn.Module):
|
||||
"""
|
||||
self.bs = pd_scores.size(0)
|
||||
self.n_max_boxes = gt_bboxes.size(1)
|
||||
self.roll_out = self.n_max_boxes > self.roll_out_thr if self.roll_out_thr else False
|
||||
|
||||
if self.n_max_boxes == 0:
|
||||
device = gt_bboxes.device
|
||||
@ -119,39 +109,34 @@ class TaskAlignedAssigner(nn.Module):
|
||||
return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
|
||||
|
||||
def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
|
||||
# get anchor_align metric, (b, max_num_obj, h*w)
|
||||
align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes)
|
||||
# get in_gts mask, (b, max_num_obj, h*w)
|
||||
mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes, roll_out=self.roll_out)
|
||||
mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
|
||||
# get anchor_align metric, (b, max_num_obj, h*w)
|
||||
align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
|
||||
# get topk_metric mask, (b, max_num_obj, h*w)
|
||||
mask_topk = self.select_topk_candidates(align_metric * mask_in_gts,
|
||||
topk_mask=mask_gt.repeat([1, 1, self.topk]).bool())
|
||||
mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.repeat([1, 1, self.topk]).bool())
|
||||
# merge all mask to a final mask, (b, max_num_obj, h*w)
|
||||
mask_pos = mask_topk * mask_in_gts * mask_gt
|
||||
|
||||
return mask_pos, align_metric, overlaps
|
||||
|
||||
def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes):
|
||||
if self.roll_out:
|
||||
align_metric = torch.empty((self.bs, self.n_max_boxes, pd_scores.shape[1]), device=pd_scores.device)
|
||||
overlaps = torch.empty((self.bs, self.n_max_boxes, pd_scores.shape[1]), device=pd_scores.device)
|
||||
ind_0 = torch.empty(self.n_max_boxes, dtype=torch.long)
|
||||
for b in range(self.bs):
|
||||
ind_0[:], ind_2 = b, gt_labels[b].squeeze(-1).long()
|
||||
# get the scores of each grid for each gt cls
|
||||
bbox_scores = pd_scores[ind_0, :, ind_2] # b, max_num_obj, h*w
|
||||
overlaps[b] = bbox_iou(gt_bboxes[b].unsqueeze(1), pd_bboxes[b].unsqueeze(0), xywh=False,
|
||||
CIoU=True).squeeze(2).clamp(0)
|
||||
align_metric[b] = bbox_scores.pow(self.alpha) * overlaps[b].pow(self.beta)
|
||||
else:
|
||||
def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
|
||||
na = pd_bboxes.shape[-2]
|
||||
mask_gt = mask_gt.bool() # b, max_num_obj, h*w
|
||||
overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
|
||||
bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)
|
||||
|
||||
ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long) # 2, b, max_num_obj
|
||||
ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes) # b, max_num_obj
|
||||
ind[1] = gt_labels.long().squeeze(-1) # b, max_num_obj
|
||||
# get the scores of each grid for each gt cls
|
||||
bbox_scores = pd_scores[ind[0], :, ind[1]] # b, max_num_obj, h*w
|
||||
bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt] # b, max_num_obj, h*w
|
||||
|
||||
# (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
|
||||
pd_boxes = pd_bboxes.unsqueeze(1).repeat(1, self.n_max_boxes, 1, 1)[mask_gt]
|
||||
gt_boxes = gt_bboxes.unsqueeze(2).repeat(1, 1, na, 1)[mask_gt]
|
||||
overlaps[mask_gt] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp(0)
|
||||
|
||||
overlaps = bbox_iou(gt_bboxes.unsqueeze(2), pd_bboxes.unsqueeze(1), xywh=False,
|
||||
CIoU=True).squeeze(3).clamp(0)
|
||||
align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
|
||||
return align_metric, overlaps
|
||||
|
||||
@ -170,12 +155,10 @@ class TaskAlignedAssigner(nn.Module):
|
||||
# (b, max_num_obj, topk)
|
||||
topk_idxs[~topk_mask] = 0
|
||||
# (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
|
||||
if self.roll_out:
|
||||
is_in_topk = torch.empty(metrics.shape, dtype=torch.long, device=metrics.device)
|
||||
for b in range(len(topk_idxs)):
|
||||
is_in_topk[b] = F.one_hot(topk_idxs[b], num_anchors).sum(-2)
|
||||
else:
|
||||
is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(-2)
|
||||
is_in_topk = torch.zeros(metrics.shape, dtype=torch.long, device=metrics.device)
|
||||
for it in range(self.topk):
|
||||
is_in_topk += F.one_hot(topk_idxs[:, :, it], num_anchors)
|
||||
# is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(-2)
|
||||
# filter invalid bboxes
|
||||
is_in_topk = torch.where(is_in_topk > 1, 0, is_in_topk)
|
||||
return is_in_topk.to(metrics.dtype)
|
||||
|
@ -33,6 +33,7 @@ class ClassificationValidator(BaseValidator):
|
||||
|
||||
def finalize_metrics(self, *args, **kwargs):
|
||||
self.metrics.speed = self.speed
|
||||
# self.metrics.confusion_matrix = self.confusion_matrix # TODO: classification ConfusionMatrix
|
||||
|
||||
def get_stats(self):
|
||||
self.metrics.process(self.targets, self.pred)
|
||||
|
@ -124,13 +124,8 @@ class Loss:
|
||||
self.device = device
|
||||
|
||||
self.use_dfl = m.reg_max > 1
|
||||
roll_out_thr = h.min_memory if h.min_memory > 1 else 64 if h.min_memory else 0 # 64 is default
|
||||
|
||||
self.assigner = TaskAlignedAssigner(topk=10,
|
||||
num_classes=self.nc,
|
||||
alpha=0.5,
|
||||
beta=6.0,
|
||||
roll_out_thr=roll_out_thr)
|
||||
self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
|
||||
self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device)
|
||||
self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
|
||||
|
||||
|
@ -113,6 +113,7 @@ class DetectionValidator(BaseValidator):
|
||||
|
||||
def finalize_metrics(self, *args, **kwargs):
|
||||
self.metrics.speed = self.speed
|
||||
self.metrics.confusion_matrix = self.confusion_matrix
|
||||
|
||||
def get_stats(self):
|
||||
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
|
||||
|
@ -121,6 +121,7 @@ class SegmentationValidator(DetectionValidator):
|
||||
|
||||
def finalize_metrics(self, *args, **kwargs):
|
||||
self.metrics.speed = self.speed
|
||||
self.metrics.confusion_matrix = self.confusion_matrix
|
||||
|
||||
def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False):
|
||||
"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user