ultralytics 8.1.17 fix ClassificationDataset caching (#8358)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Glenn Jocher 2024-02-21 11:48:36 +01:00 committed by GitHub
parent 604b9d0794
commit 2945cfc6ef
3 changed files with 24 additions and 17 deletions

mkdocs.yml

@@ -197,7 +197,7 @@ nav:
       - Python: usage/python.md
       - Callbacks: usage/callbacks.md
       - Configuration: usage/cfg.md
-      - Simple-Utilities: usage/simple-utilities.md
+      - Simple Utilities: usage/simple-utilities.md
       - Advanced Customization: usage/engine.md
   - Modes:
       - modes/index.md

ultralytics/__init__.py

@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
-__version__ = "8.1.16"
+__version__ = "8.1.17"
 from ultralytics.data.explorer.explorer import Explorer
 from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld

ultralytics/data/dataset.py

@@ -226,35 +226,42 @@ class YOLODataset(BaseDataset):
 # Classification dataloaders -------------------------------------------------------------------------------------------
 class ClassificationDataset(torchvision.datasets.ImageFolder):
     """
-    YOLO Classification Dataset.
+    Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image
+    augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep
+    learning models, with optional image transformations and caching mechanisms to speed up training.
 
-    Args:
-        root (str): Dataset path.
+    This class allows for augmentations using both torchvision and Albumentations libraries, and supports caching images
+    in RAM or on disk to reduce IO overhead during training. Additionally, it implements a robust verification process
+    to ensure data integrity and consistency.
 
     Attributes:
-        cache_ram (bool): True if images should be cached in RAM, False otherwise.
-        cache_disk (bool): True if images should be cached on disk, False otherwise.
-        samples (list): List of samples containing file, index, npy, and im.
-        torch_transforms (callable): torchvision transforms applied to the dataset.
-        album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True.
+        cache_ram (bool): Indicates if caching in RAM is enabled.
+        cache_disk (bool): Indicates if caching on disk is enabled.
+        samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache
+            file (if caching on disk), and optionally the loaded image array (if caching in RAM).
+        torch_transforms (callable): PyTorch transforms to be applied to the images.
     """
 
-    def __init__(self, root, args, augment=False, cache=False, prefix=""):
+    def __init__(self, root, args, augment=False, prefix=""):
         """
         Initialize YOLO object with root, image size, augmentations, and cache settings.
 
         Args:
-            root (str): Dataset path.
-            args (Namespace): Argument parser containing dataset related settings.
-            augment (bool, optional): True if dataset should be augmented, False otherwise. Defaults to False.
-            cache (bool | str | optional): Cache setting, can be True, False, 'ram' or 'disk'. Defaults to False.
+            root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
+            args (Namespace): Configuration containing dataset-related settings such as image size, augmentation
+                parameters, and cache settings. It includes attributes like `imgsz` (image size), `fraction` (fraction
+                of data to use), `scale`, `fliplr`, `flipud`, `cache` (disk or RAM caching for faster training),
+                `auto_augment`, `hsv_h`, `hsv_s`, `hsv_v`, and `crop_fraction`.
+            augment (bool, optional): Whether to apply augmentations to the dataset. Default is False.
+            prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and
+                debugging. Default is an empty string.
         """
         super().__init__(root=root)
         if augment and args.fraction < 1.0:  # reduce training fraction
             self.samples = self.samples[: round(len(self.samples) * args.fraction)]
         self.prefix = colorstr(f"{prefix}: ") if prefix else ""
-        self.cache_ram = cache is True or cache == "ram"
-        self.cache_disk = cache == "disk"
+        self.cache_ram = args.cache is True or args.cache == "ram"  # cache images into RAM
+        self.cache_disk = args.cache == "disk"  # cache images on hard drive as uncompressed *.npy files
         self.samples = self.verify_images()  # filter out bad images
         self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples]  # file, index, npy, im
         scale = (1.0 - args.scale, 1.0)  # (0.08, 1.0)
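
The behavioral change in this hunk is that the caching mode is now read from args.cache rather than from a separate cache parameter, so the constructor signature drops cache entirely. A minimal sketch of how the two flags resolve for the possible cache values is shown below; SimpleNamespace is only an illustrative stand-in for the real configuration object passed as args.

from types import SimpleNamespace

# Illustrative only: SimpleNamespace stands in for the actual args/config object.
for cache in (True, "ram", "disk", False, None):
    args = SimpleNamespace(cache=cache)
    cache_ram = args.cache is True or args.cache == "ram"  # keep decoded images in memory
    cache_disk = args.cache == "disk"                      # save decoded images as *.npy files
    print(f"cache={cache!r:>7} -> cache_ram={cache_ram}, cache_disk={cache_disk}")

With this change, code that previously constructed the dataset with an explicit cache= keyword would instead set the value on the args object (e.g. args.cache = "ram"), keeping a single source of truth for the caching setting.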