mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 05:24:22 +08:00
ultralytics 8.0.43
optimized Results
class and fixes (#1069)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Alexander Duda <Alexander.Duda@me.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
parent
f2a7a29e53
commit
fe61018975
15
.github/workflows/ci.yaml
vendored
15
.github/workflows/ci.yaml
vendored
@ -18,7 +18,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: ['3.10'] # requires python<=3.9
|
||||
python-version: ['3.10']
|
||||
model: [yolov8n]
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
@ -51,17 +51,17 @@ jobs:
|
||||
shell: python
|
||||
run: |
|
||||
from ultralytics.yolo.utils.benchmarks import run_benchmarks
|
||||
run_benchmarks(model='yolov8n.pt', imgsz=160, half=False, hard_fail=False)
|
||||
run_benchmarks(model='${{ matrix.model }}.pt', imgsz=160, half=False, hard_fail=False)
|
||||
- name: Benchmark SegmentationModel
|
||||
shell: python
|
||||
run: |
|
||||
from ultralytics.yolo.utils.benchmarks import run_benchmarks
|
||||
run_benchmarks(model='yolov8n-seg.pt', imgsz=160, half=False, hard_fail=False)
|
||||
run_benchmarks(model='${{ matrix.model }}-seg.pt', imgsz=160, half=False, hard_fail=False)
|
||||
- name: Benchmark ClassificationModel
|
||||
shell: python
|
||||
run: |
|
||||
from ultralytics.yolo.utils.benchmarks import run_benchmarks
|
||||
run_benchmarks(model='yolov8n-cls.pt', imgsz=160, half=False, hard_fail=False)
|
||||
run_benchmarks(model='${{ matrix.model }}-cls.pt', imgsz=160, half=False, hard_fail=False)
|
||||
|
||||
Tests:
|
||||
timeout-minutes: 60
|
||||
@ -70,13 +70,10 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: ['3.10']
|
||||
python-version: ['3.7', '3.8', '3.9', '3.10']
|
||||
model: [yolov8n]
|
||||
torch: [latest]
|
||||
include:
|
||||
- os: ubuntu-latest
|
||||
python-version: '3.7' # '3.6.8' min
|
||||
model: yolov8n
|
||||
- os: ubuntu-latest
|
||||
python-version: '3.8' # torch 1.7.0 requires python >=3.6, <=3.8
|
||||
model: yolov8n
|
||||
@ -123,9 +120,7 @@ jobs:
|
||||
run: |
|
||||
import os
|
||||
import ultralytics
|
||||
from ultralytics import hub, yolo
|
||||
key = os.environ['APIKEY']
|
||||
print(ultralytics.__version__)
|
||||
ultralytics.checks()
|
||||
# ultralytics.reset_model(key) # reset trained model
|
||||
# ultralytics.start(key) # train model
|
||||
|
@ -130,6 +130,7 @@ given task.
|
||||
| half | False | use half precision (FP16) |
|
||||
| device | null | device to run on, i.e. cuda device=0/1/2/3 or device=cpu |
|
||||
| show | False | show results if possible |
|
||||
| save | False | save images with results |
|
||||
| save_txt | False | save results as .txt file |
|
||||
| save_conf | False | save results with confidence scores |
|
||||
| save_crop | False | save cropped images with results |
|
||||
|
@ -3,7 +3,7 @@
|
||||
# Local usage: pip install pre-commit, pre-commit run --all-files
|
||||
|
||||
[metadata]
|
||||
license_file = LICENSE
|
||||
license_files = LICENSE
|
||||
description_file = README.md
|
||||
|
||||
[tool:pytest]
|
||||
|
@ -59,7 +59,7 @@ def test_segment():
|
||||
# Predictor
|
||||
pred = segment.SegmentationPredictor(overrides={'imgsz': [64, 64]})
|
||||
result = pred(source=SOURCE, model=f'{MODEL}-seg.pt')
|
||||
assert len(result) == 2, 'predictor test failed'
|
||||
assert len(result), 'predictor test failed'
|
||||
|
||||
# Test resume
|
||||
overrides['resume'] = trainer.last
|
||||
@ -97,4 +97,4 @@ def test_classify():
|
||||
# Predictor
|
||||
pred = classify.ClassificationPredictor(overrides={'imgsz': [64, 64]})
|
||||
result = pred(source=SOURCE, model=trainer.best)
|
||||
assert len(result) == 2, 'predictor test failed'
|
||||
assert len(result), 'predictor test failed'
|
||||
|
@ -14,6 +14,13 @@ from ultralytics.yolo.utils import LINUX, ROOT, SETTINGS
|
||||
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt'
|
||||
CFG = 'yolov8n.yaml'
|
||||
SOURCE = ROOT / 'assets/bus.jpg'
|
||||
SOURCE_GREYSCALE = Path(f'{SOURCE.parent / SOURCE.stem}_greyscale.jpg')
|
||||
SOURCE_RGBA = Path(f'{SOURCE.parent / SOURCE.stem}_4ch.png')
|
||||
|
||||
# Convert SOURCE to greyscale and 4-ch
|
||||
im = Image.open(SOURCE)
|
||||
im.convert('L').save(SOURCE_GREYSCALE) # greyscale
|
||||
im.convert('RGBA').save(SOURCE_RGBA) # 4-ch PNG with alpha
|
||||
|
||||
|
||||
def test_model_forward():
|
||||
@ -42,8 +49,7 @@ def test_predict_dir():
|
||||
|
||||
def test_predict_img():
|
||||
model = YOLO(MODEL)
|
||||
img = Image.open(str(SOURCE))
|
||||
output = model(source=img, save=True, verbose=True) # PIL
|
||||
output = model(source=Image.open(SOURCE), save=True, verbose=True) # PIL
|
||||
assert len(output) == 1, 'predict test failed'
|
||||
img = cv2.imread(str(SOURCE))
|
||||
output = model(source=img, save=True, save_txt=True) # ndarray
|
||||
@ -67,6 +73,13 @@ def test_predict_img():
|
||||
assert len(output) == 6, 'predict test failed!'
|
||||
|
||||
|
||||
def test_predict_grey_and_4ch():
|
||||
model = YOLO(MODEL)
|
||||
for f in SOURCE_RGBA, SOURCE_GREYSCALE:
|
||||
for source in Image.open(f), cv2.imread(str(f)), f:
|
||||
model(source, save=True, verbose=True)
|
||||
|
||||
|
||||
def test_val():
|
||||
model = YOLO(MODEL)
|
||||
model.val(data='coco8.yaml', imgsz=32)
|
||||
@ -151,6 +164,7 @@ def test_predict_callback_and_setup():
|
||||
# results -> List[batch_size]
|
||||
path, _, im0s, _, _ = predictor.batch
|
||||
# print('on_predict_batch_end', im0s[0].shape)
|
||||
im0s = im0s if isinstance(im0s, list) else [im0s]
|
||||
bs = [predictor.dataset.bs for _ in range(len(path))]
|
||||
predictor.results = zip(predictor.results, im0s, bs)
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
__version__ = '8.0.42'
|
||||
__version__ = '8.0.43'
|
||||
|
||||
from ultralytics.yolo.engine.model import YOLO
|
||||
from ultralytics.yolo.utils.checks import check_yolo as checks
|
||||
|
@ -1,3 +1,4 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
from .track import register_tracker
|
||||
from .trackers import BOTSORT, BYTETracker
|
||||
|
@ -1,13 +1,16 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
from ultralytics.yolo.utils.checks import check_requirements, check_yaml
|
||||
|
||||
check_requirements('lap') # for linear_assignment
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.tracker import BOTSORT, BYTETracker
|
||||
from ultralytics.yolo.utils import IterableSimpleNamespace, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_requirements, check_yaml
|
||||
|
||||
from .trackers import BOTSORT, BYTETracker
|
||||
|
||||
TRACKER_MAP = {'bytetrack': BYTETracker, 'botsort': BOTSORT}
|
||||
check_requirements('lap') # for linear_assignment
|
||||
|
||||
|
||||
def on_predict_start(predictor):
|
||||
|
@ -18,7 +18,7 @@ CLI_HELP_MSG = \
|
||||
yolo TASK MODE ARGS
|
||||
|
||||
Where TASK (optional) is one of [detect, segment, classify]
|
||||
MODE (required) is one of [train, val, predict, export]
|
||||
MODE (required) is one of [train, val, predict, export, track]
|
||||
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
|
||||
See all ARGS at https://docs.ultralytics.com/cfg or with 'yolo cfg'
|
||||
|
||||
@ -197,7 +197,7 @@ def entrypoint(debug=''):
|
||||
|
||||
# Define tasks and modes
|
||||
tasks = 'detect', 'segment', 'classify'
|
||||
modes = 'train', 'val', 'predict', 'export', 'track'
|
||||
modes = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
|
||||
special = {
|
||||
'help': lambda: LOGGER.info(CLI_HELP_MSG),
|
||||
'checks': checks.check_yolo,
|
||||
|
@ -290,13 +290,15 @@ class LoadPilAndNumpy:
|
||||
self.transforms = transforms
|
||||
self.mode = 'image'
|
||||
# generate fake paths
|
||||
self.paths = [f'image{i}.jpg' for i in range(len(self.im0))]
|
||||
self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(self.im0)]
|
||||
self.bs = len(self.im0)
|
||||
|
||||
@staticmethod
|
||||
def _single_check(im):
|
||||
assert isinstance(im, (Image.Image, np.ndarray)), f'Expected PIL/np.ndarray image type, but got {type(im)}'
|
||||
if isinstance(im, Image.Image):
|
||||
if im.mode != 'RGB':
|
||||
im = im.convert('RGB')
|
||||
im = np.asarray(im)[:, :, ::-1]
|
||||
im = np.ascontiguousarray(im) # contiguous
|
||||
return im
|
||||
|
@ -1045,7 +1045,7 @@ class HUBDatasetStats():
|
||||
autodownload: Attempt to download dataset if not found locally
|
||||
|
||||
Usage
|
||||
from utils.dataloaders import HUBDatasetStats
|
||||
from ultralytics.yolo.data.dataloaders.v5loader import HUBDatasetStats
|
||||
stats = HUBDatasetStats('coco128.yaml', autodownload=True) # usage 1
|
||||
stats = HUBDatasetStats('path/to/coco128.zip') # usage 2
|
||||
stats.get_json(save=False)
|
||||
@ -1055,15 +1055,15 @@ class HUBDatasetStats():
|
||||
def __init__(self, path='coco128.yaml', autodownload=False):
|
||||
# Initialize class
|
||||
zipped, data_dir, yaml_path = self._unzip(Path(path))
|
||||
try:
|
||||
data = yaml_load(check_yaml(yaml_path)) # data dict
|
||||
if zipped:
|
||||
data['path'] = data_dir
|
||||
except Exception as e:
|
||||
raise Exception('error/HUB/dataset_stats/yaml_load') from e
|
||||
# try:
|
||||
# data = yaml_load(check_yaml(yaml_path)) # data dict
|
||||
# if zipped:
|
||||
# data['path'] = data_dir
|
||||
# except Exception as e:
|
||||
# raise Exception('error/HUB/dataset_stats/yaml_load') from e
|
||||
|
||||
check_det_dataset(data, autodownload) # download dataset if missing
|
||||
self.hub_dir = Path(data['path'] + '-hub')
|
||||
data = check_det_dataset(yaml_path, autodownload) # download dataset if missing
|
||||
self.hub_dir = Path(str(data['path']) + '-hub')
|
||||
self.im_dir = self.hub_dir / 'images'
|
||||
self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images
|
||||
self.stats = {'nc': data['nc'], 'names': list(data['names'].values())} # statistics dictionary
|
||||
|
@ -9,7 +9,7 @@ from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, Segmentat
|
||||
guess_model_task, nn)
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, callbacks, yaml_load
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_yaml
|
||||
from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS
|
||||
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
|
||||
@ -203,7 +203,7 @@ class YOLO:
|
||||
|
||||
@smart_inference_mode()
|
||||
def track(self, source=None, stream=False, **kwargs):
|
||||
from ultralytics.tracker.track import register_tracker
|
||||
from ultralytics.tracker import register_tracker
|
||||
register_tracker(self)
|
||||
# ByteTrack-based method needs low confidence predictions as input
|
||||
conf = kwargs.get('conf') or 0.1
|
||||
@ -237,6 +237,20 @@ class YOLO:
|
||||
|
||||
return validator.metrics
|
||||
|
||||
@smart_inference_mode()
|
||||
def benchmark(self, **kwargs):
|
||||
"""
|
||||
Benchmark a model on all export formats.
|
||||
|
||||
Args:
|
||||
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
|
||||
"""
|
||||
from ultralytics.yolo.utils.benchmarks import run_benchmarks
|
||||
overrides = self.model.args.copy()
|
||||
overrides.update(kwargs)
|
||||
overrides = {**DEFAULT_CFG_DICT, **overrides} # fill in missing overrides keys with defaults
|
||||
return run_benchmarks(model=self, imgsz=overrides['imgsz'], half=overrides['half'], device=overrides['device'])
|
||||
|
||||
def export(self, **kwargs):
|
||||
"""
|
||||
Export model.
|
||||
|
@ -194,7 +194,7 @@ class BasePredictor:
|
||||
|
||||
# Print time (inference-only)
|
||||
if self.args.verbose:
|
||||
LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms")
|
||||
LOGGER.info(f'{s}{self.dt[1].dt * 1E3:.1f}ms')
|
||||
|
||||
# Release assets
|
||||
if isinstance(self.vid_writer[-1], cv2.VideoWriter):
|
||||
|
@ -1,3 +1,10 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
"""
|
||||
Ultralytics Results, Boxes and Masks classes for handling inference results
|
||||
|
||||
Usage: See https://docs.ultralytics.com/predict/
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from functools import lru_cache
|
||||
|
||||
@ -36,7 +43,7 @@ class Results:
|
||||
self.probs = probs if probs is not None else None
|
||||
self.names = names
|
||||
self.path = path
|
||||
self.comp = ['boxes', 'masks', 'probs']
|
||||
self._keys = (k for k in ('boxes', 'masks', 'probs') if getattr(self, k) is not None)
|
||||
|
||||
def pandas(self):
|
||||
pass
|
||||
@ -44,10 +51,8 @@ class Results:
|
||||
|
||||
def __getitem__(self, idx):
|
||||
r = Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
||||
for item in self.comp:
|
||||
if getattr(self, item) is None:
|
||||
continue
|
||||
setattr(r, item, getattr(self, item)[idx])
|
||||
for k in self._keys:
|
||||
setattr(r, k, getattr(self, k)[idx])
|
||||
return r
|
||||
|
||||
def update(self, boxes=None, masks=None, probs=None):
|
||||
@ -60,57 +65,37 @@ class Results:
|
||||
|
||||
def cpu(self):
|
||||
r = Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
||||
for item in self.comp:
|
||||
if getattr(self, item) is None:
|
||||
continue
|
||||
setattr(r, item, getattr(self, item).cpu())
|
||||
for k in self._keys:
|
||||
setattr(r, k, getattr(self, k).cpu())
|
||||
return r
|
||||
|
||||
def numpy(self):
|
||||
r = Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
||||
for item in self.comp:
|
||||
if getattr(self, item) is None:
|
||||
continue
|
||||
setattr(r, item, getattr(self, item).numpy())
|
||||
for k in self._keys:
|
||||
setattr(r, k, getattr(self, k).numpy())
|
||||
return r
|
||||
|
||||
def cuda(self):
|
||||
r = Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
||||
for item in self.comp:
|
||||
if getattr(self, item) is None:
|
||||
continue
|
||||
setattr(r, item, getattr(self, item).cuda())
|
||||
for k in self._keys:
|
||||
setattr(r, k, getattr(self, k).cuda())
|
||||
return r
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
r = Results(orig_img=self.orig_img, path=self.path, names=self.names)
|
||||
for item in self.comp:
|
||||
if getattr(self, item) is None:
|
||||
continue
|
||||
setattr(r, item, getattr(self, item).to(*args, **kwargs))
|
||||
for k in self._keys:
|
||||
setattr(r, k, getattr(self, k).to(*args, **kwargs))
|
||||
return r
|
||||
|
||||
def __len__(self):
|
||||
for item in self.comp:
|
||||
if getattr(self, item) is None:
|
||||
continue
|
||||
return len(getattr(self, item))
|
||||
for k in self._keys:
|
||||
return len(getattr(self, k))
|
||||
|
||||
def __str__(self):
|
||||
str_out = ''
|
||||
for item in self.comp:
|
||||
if getattr(self, item) is None:
|
||||
continue
|
||||
str_out = str_out + getattr(self, item).__str__()
|
||||
return str_out
|
||||
return ''.join(getattr(self, k).__str__() for k in self._keys)
|
||||
|
||||
def __repr__(self):
|
||||
str_out = ''
|
||||
for item in self.comp:
|
||||
if getattr(self, item) is None:
|
||||
continue
|
||||
str_out = str_out + getattr(self, item).__repr__()
|
||||
return str_out
|
||||
return ''.join(getattr(self, k).__repr__() for k in self._keys)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
name = self.__class__.__name__
|
||||
@ -226,20 +211,16 @@ class Boxes:
|
||||
return self.xywh / self.orig_shape[[1, 0, 1, 0]]
|
||||
|
||||
def cpu(self):
|
||||
boxes = self.boxes.cpu()
|
||||
return Boxes(boxes, self.orig_shape)
|
||||
return Boxes(self.boxes.cpu(), self.orig_shape)
|
||||
|
||||
def numpy(self):
|
||||
boxes = self.boxes.numpy()
|
||||
return Boxes(boxes, self.orig_shape)
|
||||
return Boxes(self.boxes.numpy(), self.orig_shape)
|
||||
|
||||
def cuda(self):
|
||||
boxes = self.boxes.cuda()
|
||||
return Boxes(boxes, self.orig_shape)
|
||||
return Boxes(self.boxes.cuda(), self.orig_shape)
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
boxes = self.boxes.to(*args, **kwargs)
|
||||
return Boxes(boxes, self.orig_shape)
|
||||
return Boxes(self.boxes.to(*args, **kwargs), self.orig_shape)
|
||||
|
||||
def pandas(self):
|
||||
LOGGER.info('results.pandas() method not yet implemented')
|
||||
@ -272,8 +253,7 @@ class Boxes:
|
||||
f'shape: {self.boxes.shape}\n' + f'dtype: {self.boxes.dtype}\n + {self.boxes.__repr__()}')
|
||||
|
||||
def __getitem__(self, idx):
|
||||
boxes = self.boxes[idx]
|
||||
return Boxes(boxes, self.orig_shape)
|
||||
return Boxes(self.boxes[idx], self.orig_shape)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
name = self.__class__.__name__
|
||||
@ -331,20 +311,16 @@ class Masks:
|
||||
return self.masks
|
||||
|
||||
def cpu(self):
|
||||
masks = self.masks.cpu()
|
||||
return Masks(masks, self.orig_shape)
|
||||
return Masks(self.masks.cpu(), self.orig_shape)
|
||||
|
||||
def numpy(self):
|
||||
masks = self.masks.numpy()
|
||||
return Masks(masks, self.orig_shape)
|
||||
return Masks(self.masks.numpy(), self.orig_shape)
|
||||
|
||||
def cuda(self):
|
||||
masks = self.masks.cuda()
|
||||
return Masks(masks, self.orig_shape)
|
||||
return Masks(self.masks.cuda(), self.orig_shape)
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
masks = self.masks.to(*args, **kwargs)
|
||||
return Masks(masks, self.orig_shape)
|
||||
return Masks(self.masks.to(*args, **kwargs), self.orig_shape)
|
||||
|
||||
def __len__(self): # override len(results)
|
||||
return len(self.masks)
|
||||
@ -357,8 +333,7 @@ class Masks:
|
||||
f'shape: {self.masks.shape}\n' + f'dtype: {self.masks.dtype}\n + {self.masks.__repr__()}')
|
||||
|
||||
def __getitem__(self, idx):
|
||||
masks = self.masks[idx]
|
||||
return Masks(masks, self.orig_shape)
|
||||
return Masks(self.masks[idx], self.orig_shape)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
name = self.__class__.__name__
|
||||
|
@ -243,6 +243,8 @@ class BaseTrainer:
|
||||
metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix='val')
|
||||
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
|
||||
self.ema = ModelEMA(self.model)
|
||||
if self.args.plots:
|
||||
self.plot_training_labels()
|
||||
self.resume_training(ckpt)
|
||||
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
|
||||
self.run_callbacks('on_pretrain_routine_end')
|
||||
@ -501,6 +503,9 @@ class BaseTrainer:
|
||||
def plot_training_samples(self, batch, ni):
|
||||
pass
|
||||
|
||||
def plot_training_labels(self):
|
||||
pass
|
||||
|
||||
def save_metrics(self, metrics):
|
||||
keys, vals = list(metrics.keys()), list(metrics.values())
|
||||
n = len(metrics) + 1 # number of cols
|
||||
|
@ -28,7 +28,7 @@ from tqdm import tqdm
|
||||
from ultralytics.nn.autobackend import AutoBackend
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, emojis
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
|
||||
from ultralytics.yolo.utils.checks import check_imgsz
|
||||
from ultralytics.yolo.utils.files import increment_path
|
||||
from ultralytics.yolo.utils.ops import Profile
|
||||
@ -194,6 +194,8 @@ class BaseValidator:
|
||||
self.logger.info(f'Saving {f.name}...')
|
||||
json.dump(self.jdict, f) # flatten and save
|
||||
stats = self.eval_json(stats) # update stats
|
||||
if self.args.plots or self.args.save_json:
|
||||
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
|
||||
return stats
|
||||
|
||||
def run_callbacks(self, event: str):
|
||||
|
@ -27,13 +27,14 @@ import time
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
from ultralytics import YOLO
|
||||
from ultralytics.yolo.engine.exporter import export_formats
|
||||
from ultralytics.yolo.utils import LOGGER, SETTINGS
|
||||
from ultralytics.yolo.utils import LOGGER, ROOT, SETTINGS
|
||||
from ultralytics.yolo.utils.checks import check_yolo
|
||||
from ultralytics.yolo.utils.downloads import download
|
||||
from ultralytics.yolo.utils.files import file_size
|
||||
from ultralytics.yolo.utils.torch_utils import select_device
|
||||
|
||||
|
||||
def run_benchmarks(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
|
||||
@ -41,8 +42,9 @@ def run_benchmarks(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
|
||||
half=False,
|
||||
device='cpu',
|
||||
hard_fail=False):
|
||||
device = torch.device(int(device) if device.isnumeric() else device)
|
||||
model = YOLO(model)
|
||||
device = select_device(device, verbose=False)
|
||||
if isinstance(model, (str, Path)):
|
||||
model = YOLO(model)
|
||||
|
||||
y = []
|
||||
t0 = time.time()
|
||||
@ -65,6 +67,11 @@ def run_benchmarks(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
|
||||
export = YOLO(filename)
|
||||
assert suffix in str(filename), 'export failed'
|
||||
|
||||
# Predict
|
||||
if not (ROOT / 'assets/bus.jpg').exists():
|
||||
download(url='https://ultralytics.com/images/bus.jpg', dir=ROOT / 'assets')
|
||||
export.predict(ROOT / 'assets/bus.jpg', imgsz=imgsz, device=device, half=half) # test
|
||||
|
||||
# Validate
|
||||
if model.task == 'detect':
|
||||
data, key = 'coco128.yaml', 'metrics/mAP50-95(B)'
|
||||
@ -96,6 +103,8 @@ def run_benchmarks(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
|
||||
floor = eval(hard_fail) # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
|
||||
assert all(x > floor for x in metrics if pd.notna(x)), f'HARD FAIL: metric < floor {floor}'
|
||||
|
||||
return df
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_benchmarks()
|
||||
|
@ -5,19 +5,24 @@ import math
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import seaborn as sn
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from PIL import __version__ as pil_version
|
||||
|
||||
from ultralytics.yolo.utils import LOGGER, threaded
|
||||
from ultralytics.yolo.utils import LOGGER, TryExcept, threaded
|
||||
|
||||
from .checks import check_font, check_version, is_ascii
|
||||
from .files import increment_path
|
||||
from .ops import clip_coords, scale_image, xywh2xyxy, xyxy2xywh
|
||||
|
||||
matplotlib.rc('font', **{'size': 11})
|
||||
matplotlib.use('Agg') # for writing to files only
|
||||
|
||||
|
||||
class Colors:
|
||||
# Ultralytics color palette https://ultralytics.com/
|
||||
@ -152,6 +157,52 @@ class Annotator:
|
||||
return np.asarray(self.im)
|
||||
|
||||
|
||||
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
|
||||
def plot_labels(boxes, cls, names=(), save_dir=Path('')):
|
||||
# plot dataset labels
|
||||
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
|
||||
b = boxes.transpose() # classes, boxes
|
||||
nc = int(cls.max() + 1) # number of classes
|
||||
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
|
||||
|
||||
# seaborn correlogram
|
||||
sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
|
||||
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
|
||||
plt.close()
|
||||
|
||||
# matplotlib labels
|
||||
matplotlib.use('svg') # faster
|
||||
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
|
||||
y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
|
||||
with contextlib.suppress(Exception): # color histogram bars by class
|
||||
[y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # known issue #3195
|
||||
ax[0].set_ylabel('instances')
|
||||
if 0 < len(names) < 30:
|
||||
ax[0].set_xticks(range(len(names)))
|
||||
ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
|
||||
else:
|
||||
ax[0].set_xlabel('classes')
|
||||
sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
|
||||
sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
|
||||
|
||||
# rectangles
|
||||
boxes[:, 0:2] = 0.5 # center
|
||||
boxes = xywh2xyxy(boxes) * 2000
|
||||
img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
|
||||
for cls, box in zip(cls[:1000], boxes[:1000]):
|
||||
ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
|
||||
ax[1].imshow(img)
|
||||
ax[1].axis('off')
|
||||
|
||||
for a in [0, 1, 2, 3]:
|
||||
for s in ['top', 'right', 'left', 'bottom']:
|
||||
ax[a].spines[s].set_visible(False)
|
||||
|
||||
plt.savefig(save_dir / 'labels.jpg', dpi=200)
|
||||
matplotlib.use('Agg')
|
||||
plt.close()
|
||||
|
||||
|
||||
def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
|
||||
# Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
|
||||
xyxy = torch.Tensor(xyxy).view(-1, 4)
|
||||
|
@ -59,7 +59,7 @@ def DDP_model(model):
|
||||
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
|
||||
|
||||
|
||||
def select_device(device='', batch=0, newline=False):
|
||||
def select_device(device='', batch=0, newline=False, verbose=True):
|
||||
# device = None or 'cpu' or 0 or '0' or '0,1,2,3'
|
||||
s = f'Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
|
||||
device = str(device).lower()
|
||||
@ -102,7 +102,7 @@ def select_device(device='', batch=0, newline=False):
|
||||
s += 'CPU\n'
|
||||
arg = 'cpu'
|
||||
|
||||
if RANK == -1:
|
||||
if verbose and RANK == -1:
|
||||
LOGGER.info(s if newline else s.rstrip())
|
||||
return torch.device(arg)
|
||||
|
||||
|
@ -56,7 +56,7 @@ class DetectionPredictor(BasePredictor):
|
||||
|
||||
det = results[idx].boxes # TODO: make boxes inherit from tensors
|
||||
if len(det) == 0:
|
||||
return log_string
|
||||
return f'{log_string}(no detections), '
|
||||
for c in det.cls.unique():
|
||||
n = (det.cls == c).sum() # detections per class
|
||||
log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
from copy import copy
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@ -12,7 +13,7 @@ from ultralytics.yolo.engine.trainer import BaseTrainer
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, RANK, colorstr
|
||||
from ultralytics.yolo.utils.loss import BboxLoss
|
||||
from ultralytics.yolo.utils.ops import xywh2xyxy
|
||||
from ultralytics.yolo.utils.plotting import plot_images, plot_results
|
||||
from ultralytics.yolo.utils.plotting import plot_images, plot_labels, plot_results
|
||||
from ultralytics.yolo.utils.tal import TaskAlignedAssigner, dist2bbox, make_anchors
|
||||
from ultralytics.yolo.utils.torch_utils import de_parallel
|
||||
|
||||
@ -102,6 +103,11 @@ class DetectionTrainer(BaseTrainer):
|
||||
def plot_metrics(self):
|
||||
plot_results(file=self.csv) # save results.png
|
||||
|
||||
def plot_training_labels(self):
|
||||
boxes = np.concatenate([lb['bboxes'] for lb in self.train_loader.dataset.labels], 0)
|
||||
cls = np.concatenate([lb['cls'] for lb in self.train_loader.dataset.labels], 0)
|
||||
plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir)
|
||||
|
||||
|
||||
# Criterion class for computing training losses
|
||||
class Loss:
|
||||
|
@ -59,7 +59,7 @@ class SegmentationPredictor(DetectionPredictor):
|
||||
|
||||
result = results[idx]
|
||||
if len(result) == 0:
|
||||
return log_string
|
||||
return f'{log_string}(no detections), '
|
||||
det, mask = result.boxes, result.masks # getting tensors TODO: mask mask,box inherit for tensor
|
||||
|
||||
# Print results
|
||||
|
Loading…
x
Reference in New Issue
Block a user