Mirror of https://github.com/THU-MIG/yolov10.git, synced 2025-05-23 21:44:22 +08:00

ultralytics 8.0.160 Classify dataset scanning and caching (#4502)

This commit is contained in:
  parent b890e1c937
  commit c7ceb84fb6
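This commit bumps the package version to 8.0.160 and adds scanning and *.cache support to the classification dataset pipeline: images are verified once in a thread pool, the results are written to a *.cache file next to the split directory, and later runs load that file instead of rescanning. The scan runs automatically when a classification model is trained or validated; a minimal sketch of that entry point, assuming the small 'mnist160' example dataset (illustrative only):

from ultralytics import YOLO

model = YOLO('yolov8n-cls.pt')  # classification model
model.train(data='mnist160', epochs=1, imgsz=64)  # first run scans the images and writes a *.cache file per split
# Subsequent runs find a matching cache (same hash and version) and skip the per-image verification.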
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = '8.0.159'
+__version__ = '8.0.160'
 
 from ultralytics.models import RTDETR, SAM, YOLO
 from ultralytics.models.fastsam import FastSAM
@@ -1,5 +1,5 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
+import contextlib
 from itertools import repeat
 from multiprocessing.pool import ThreadPool
 from pathlib import Path
@@ -10,11 +10,14 @@ import torch
 import torchvision
 from tqdm import tqdm
 
-from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT, is_dir_writeable
+from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT, colorstr, is_dir_writeable
 
 from .augment import Compose, Format, Instances, LetterBox, classify_albumentations, classify_transforms, v8_transforms
 from .base import BaseDataset
-from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image_label
+from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label
 
+# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
+DATASET_CACHE_VERSION = '1.0.2'
+
 
 class YOLODataset(BaseDataset):
@@ -29,7 +32,6 @@ class YOLODataset(BaseDataset):
     Returns:
         (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
     """
-    cache_version = '1.0.2'  # dataset labels *.cache version, >= 1.0.0 for YOLOv8
 
     def __init__(self, *args, data=None, use_segments=False, use_keypoints=False, **kwargs):
         self.use_segments = use_segments
@@ -87,15 +89,7 @@ class YOLODataset(BaseDataset):
         x['hash'] = get_hash(self.label_files + self.im_files)
         x['results'] = nf, nm, ne, nc, len(self.im_files)
         x['msgs'] = msgs  # warnings
-        x['version'] = self.cache_version  # cache version
-        if is_dir_writeable(path.parent):
-            if path.exists():
-                path.unlink()  # remove *.cache file if exists
-            np.save(str(path), x)  # save cache for next time
-            path.with_suffix('.cache.npy').rename(path)  # remove .npy suffix
-            LOGGER.info(f'{self.prefix}New cache created: {path}')
-        else:
-            LOGGER.warning(f'{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
+        save_dataset_cache_file(self.prefix, path, x)
         return x
 
     def get_labels(self):
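In the hunks above (these appear to come from ultralytics/data/dataset.py), YOLODataset.cache_labels no longer writes the *.cache file inline; it assembles the cache dictionary and hands it to the new shared save_dataset_cache_file() helper, which stamps the module-level DATASET_CACHE_VERSION before saving. A rough sketch of the dictionary that gets persisted, with made-up values (the real one also carries the scanned label entries):

x = {
    'hash': 'd41d8cd98f00b204',   # hash of label + image paths, used to detect dataset changes
    'results': (8, 0, 0, 0, 8),   # found, missing, empty, corrupt, total
    'msgs': [],                   # warnings collected during the scan
}
# save_dataset_cache_file(prefix, path, x) adds x['version'] = DATASET_CACHE_VERSION
# and writes the dict to disk; load_dataset_cache_file(path) reads it back.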
@@ -103,11 +97,8 @@ class YOLODataset(BaseDataset):
         self.label_files = img2label_paths(self.im_files)
         cache_path = Path(self.label_files[0]).parent.with_suffix('.cache')
         try:
-            import gc
-            gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
-            cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True  # load dict
-            gc.enable()
-            assert cache['version'] == self.cache_version  # matches current version
+            cache, exists = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache file
+            assert cache['version'] == DATASET_CACHE_VERSION  # matches current version
             assert cache['hash'] == get_hash(self.label_files + self.im_files)  # identical hash
         except (FileNotFoundError, AssertionError, AttributeError):
             cache, exists = self.cache_labels(cache_path), False  # run cache ops
@@ -116,7 +107,7 @@ class YOLODataset(BaseDataset):
         nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, total
         if exists and LOCAL_RANK in (-1, 0):
             d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
-            tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display cache results
+            tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display results
         if cache['msgs']:
             LOGGER.info('\n'.join(cache['msgs']))  # display warnings
         if nf == 0:  # number of labels found
@@ -216,7 +207,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
         album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True.
     """
 
-    def __init__(self, root, args, augment=False, cache=False):
+    def __init__(self, root, args, augment=False, cache=False, prefix=''):
         """
         Initialize YOLO object with root, image size, augmentations, and cache settings.
 
@@ -229,8 +220,10 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
         super().__init__(root=root)
         if augment and args.fraction < 1.0:  # reduce training fraction
             self.samples = self.samples[:round(len(self.samples) * args.fraction)]
+        self.prefix = colorstr(f'{prefix}: ') if prefix else ''
         self.cache_ram = cache is True or cache == 'ram'
         self.cache_disk = cache == 'disk'
+        self.samples = self.verify_images()  # filter out bad images
         self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples]  # file, index, npy, im
         self.torch_transforms = classify_transforms(args.imgsz)
         self.album_transforms = classify_albumentations(
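With these changes ClassificationDataset filters out bad images on construction (via the new verify_images method shown in the next hunk) and tags its log lines with an optional prefix such as 'train' or 'val'. A hedged sketch of constructing the dataset directly, assuming the module path ultralytics/data/dataset.py; SimpleNamespace is only a stand-in for the full training cfg and carries just the fields this code path reads, and the root path is hypothetical:

from types import SimpleNamespace

from ultralytics.data.dataset import ClassificationDataset

args = SimpleNamespace(imgsz=224, fraction=1.0)  # stand-in cfg; real runs pass the complete args object
dataset = ClassificationDataset(root='datasets/imagenette160/train', args=args,
                                augment=False, cache='disk', prefix='train')
# First construction scans the images under root and writes train.cache;
# later constructions load that file and skip the scan.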
@@ -266,6 +259,67 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
     def __len__(self) -> int:
         return len(self.samples)
 
+    def verify_images(self):
+        """Verify all images in dataset."""
+        desc = f'{self.prefix}Scanning {self.root}...'
+        path = Path(self.root).with_suffix('.cache')  # *.cache file path
+
+        with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
+            cache = load_dataset_cache_file(path)  # attempt to load a *.cache file
+            assert cache['version'] == DATASET_CACHE_VERSION  # matches current version
+            assert cache['hash'] == get_hash([x[0] for x in self.samples])  # identical hash
+            nf, nc, n, samples = cache.pop('results')  # found, missing, empty, corrupt, total
+            if LOCAL_RANK in (-1, 0):
+                d = f'{desc} {nf} images, {nc} corrupt'
+                tqdm(None, desc=d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)
+                if cache['msgs']:
+                    LOGGER.info('\n'.join(cache['msgs']))  # display warnings
+            return samples
+
+        # Run scan if *.cache retrieval failed
+        nf, nc, msgs, samples, x = 0, 0, [], [], {}
+        with ThreadPool(NUM_THREADS) as pool:
+            results = pool.imap(func=verify_image, iterable=zip([x[0] for x in self.samples], repeat(self.prefix)))
+            pbar = tqdm(results, desc=desc, total=len(self.samples), bar_format=TQDM_BAR_FORMAT)
+            for im_file, nf_f, nc_f, msg in pbar:
+                if nf_f:
+                    samples.append((im_file, nf))
+                if msg:
+                    msgs.append(msg)
+                nf += nf_f
+                nc += nc_f
+                pbar.desc = f'{desc} {nf} images, {nc} corrupt'
+            pbar.close()
+        if msgs:
+            LOGGER.info('\n'.join(msgs))
+        x['hash'] = get_hash([x[0] for x in self.samples])
+        x['results'] = nf, nc, len(samples), samples
+        x['msgs'] = msgs  # warnings
+        save_dataset_cache_file(self.prefix, path, x)
+        return samples
+
+
+def load_dataset_cache_file(path):
+    """Load an Ultralytics *.cache dictionary from path."""
+    import gc
+    gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
+    cache = np.load(str(path), allow_pickle=True).item()  # load dict
+    gc.enable()
+    return cache
+
+
+def save_dataset_cache_file(prefix, path, x):
+    """Save an Ultralytics dataset *.cache dictionary x to path."""
+    x['version'] = DATASET_CACHE_VERSION  # add cache version
+    if is_dir_writeable(path.parent):
+        if path.exists():
+            path.unlink()  # remove *.cache file if exists
+        np.save(str(path), x)  # save cache for next time
+        path.with_suffix('.cache.npy').rename(path)  # remove .npy suffix
+        LOGGER.info(f'{prefix}New cache created: {path}')
+    else:
+        LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
+
+
 # TODO: support semantic segmentation
 class SemanticDataset(BaseDataset):
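The two module-level helpers centralize the *.cache I/O that was previously inlined in YOLODataset.cache_labels: the dictionary is serialized with np.save (garbage collection is disabled while unpickling to speed up loading) and the temporary .npy suffix is renamed away so the file ends in .cache. A small standalone sketch of that round-trip, with illustrative file names and values:

import numpy as np
from pathlib import Path

x = {'hash': 'abc123', 'results': (8, 0, 8, []), 'msgs': [], 'version': '1.0.2'}
path = Path('train.cache')
np.save(str(path), x)                          # numpy appends .npy -> train.cache.npy
path.with_suffix('.cache.npy').rename(path)    # rename back to train.cache
cache = np.load(str(path), allow_pickle=True).item()  # recover the dict
assert cache['version'] == '1.0.2'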
@@ -57,6 +57,31 @@ def exif_size(img: Image.Image):
     return s
 
 
+def verify_image(args):
+    """Verify one image."""
+    im_file, prefix = args
+    # Number (found, corrupt), message
+    nf, nc, msg = 0, 0, ''
+    try:
+        im = Image.open(im_file)
+        im.verify()  # PIL verify
+        shape = exif_size(im)  # image size
+        shape = (shape[1], shape[0])  # hw
+        assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
+        assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
+        if im.format.lower() in ('jpg', 'jpeg'):
+            with open(im_file, 'rb') as f:
+                f.seek(-2, 2)
+                if f.read() != b'\xff\xd9':  # corrupt JPEG
+                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
+                    msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
+        nf = 1
+    except Exception as e:
+        nc = 1
+        msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
+    return im_file, nf, nc, msg
+
+
 def verify_image_label(args):
     """Verify one image-label pair."""
     im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args
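verify_image mirrors the existing verify_image_label but for bare images: it opens the file with PIL, verifies it, checks minimum size and format, and rewrites truncated JPEGs in place. It is meant to be mapped over (path, prefix) tuples in a thread pool, as ClassificationDataset.verify_images does above; a minimal direct call might look like this, assuming the module path ultralytics/data/utils.py and a hypothetical image path:

from ultralytics.data.utils import verify_image

im_file, nf, nc, msg = verify_image(('datasets/imagenette160/train/n01440764/img_0001.jpg', ''))
if nc:   # nc == 1 means the image was unreadable or invalid
    print(msg)
else:    # nf == 1 means the image passed verification
    print(f'{im_file} OK')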
@@ -79,7 +79,7 @@ class ClassificationTrainer(BaseTrainer):
         return ckpt
 
     def build_dataset(self, img_path, mode='train', batch=None):
-        return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train')
+        return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train', prefix=mode)
 
     def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
         """Returns PyTorch DataLoader with transforms to preprocess images for inference."""
@@ -77,7 +77,7 @@ class ClassificationValidator(BaseValidator):
         return self.metrics.results_dict
 
     def build_dataset(self, img_path):
-        return ClassificationDataset(root=img_path, args=self.args, augment=False)
+        return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
 
     def get_dataloader(self, dataset_path, batch_size):
         """Builds and returns a data loader for classification tasks with given parameters."""