mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
Simplify augmentations (#93)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
249dfbdc05
commit
ae05d44877
@ -1,4 +1,3 @@
|
||||
import collections
|
||||
import math
|
||||
import random
|
||||
from copy import deepcopy
|
||||
@ -65,7 +64,8 @@ class Compose:
|
||||
class BaseMixTransform:
|
||||
"""This implementation is from mmyolo"""
|
||||
|
||||
def __init__(self, pre_transform=None, p=0.0) -> None:
|
||||
def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
|
||||
self.dataset = dataset
|
||||
self.pre_transform = pre_transform
|
||||
self.p = p
|
||||
|
||||
@ -73,41 +73,28 @@ class BaseMixTransform:
|
||||
if random.uniform(0, 1) > self.p:
|
||||
return labels
|
||||
|
||||
assert "dataset" in labels
|
||||
dataset = labels.pop("dataset")
|
||||
|
||||
# get index of one or three other images
|
||||
indexes = self.get_indexes(dataset)
|
||||
if not isinstance(indexes, collections.abc.Sequence):
|
||||
indexes = self.get_indexes()
|
||||
if isinstance(indexes, int):
|
||||
indexes = [indexes]
|
||||
|
||||
# get images information will be used for Mosaic or MixUp
|
||||
mix_labels = [dataset.get_label_info(index) for index in indexes]
|
||||
mix_labels = [self.dataset.get_label_info(i) for i in indexes]
|
||||
|
||||
if self.pre_transform is not None:
|
||||
for i, data in enumerate(mix_labels):
|
||||
# pre_transform may also require dataset
|
||||
data.update({"dataset": dataset})
|
||||
# before Mosaic or MixUp need to go through
|
||||
# the necessary pre_transform
|
||||
_labels = self.pre_transform(data)
|
||||
_labels.pop("dataset")
|
||||
mix_labels[i] = _labels
|
||||
mix_labels[i] = self.pre_transform(data)
|
||||
labels["mix_labels"] = mix_labels
|
||||
|
||||
# Mosaic or MixUp
|
||||
labels = self._mix_transform(labels)
|
||||
|
||||
if "mix_labels" in labels:
|
||||
labels.pop("mix_labels")
|
||||
labels["dataset"] = dataset
|
||||
|
||||
labels.pop("mix_labels", None)
|
||||
return labels
|
||||
|
||||
def _mix_transform(self, labels):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_indexes(self, dataset):
|
||||
def get_indexes(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -119,14 +106,15 @@ class Mosaic(BaseMixTransform):
|
||||
Default to (640, 640).
|
||||
"""
|
||||
|
||||
def __init__(self, imgsz=640, p=1.0, border=(0, 0)):
|
||||
def __init__(self, dataset, imgsz=640, p=1.0, border=(0, 0)):
|
||||
assert 0 <= p <= 1.0, "The probability should be in range [0, 1]. " f"got {p}."
|
||||
super().__init__(pre_transform=None, p=p)
|
||||
super().__init__(dataset=dataset, p=p)
|
||||
self.dataset = dataset
|
||||
self.imgsz = imgsz
|
||||
self.border = border
|
||||
|
||||
def get_indexes(self, dataset):
|
||||
return [random.randint(0, len(dataset) - 1) for _ in range(3)]
|
||||
def get_indexes(self):
|
||||
return [random.randint(0, len(self.dataset) - 1) for _ in range(3)]
|
||||
|
||||
def _mix_transform(self, labels):
|
||||
mosaic_labels = []
|
||||
@ -193,25 +181,19 @@ class Mosaic(BaseMixTransform):
|
||||
|
||||
class MixUp(BaseMixTransform):
|
||||
|
||||
def __init__(self, pre_transform=None, p=0.0) -> None:
|
||||
super().__init__(pre_transform=pre_transform, p=p)
|
||||
def __init__(self, dataset, pre_transform=None, p=0.0) -> None:
|
||||
super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)
|
||||
|
||||
def get_indexes(self, dataset):
|
||||
return random.randint(0, len(dataset) - 1)
|
||||
def get_indexes(self):
|
||||
return random.randint(0, len(self.dataset) - 1)
|
||||
|
||||
def _mix_transform(self, labels):
|
||||
im = labels["img"]
|
||||
labels2 = labels["mix_labels"][0]
|
||||
im2 = labels2["img"]
|
||||
cls2 = labels2["cls"]
|
||||
# Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
|
||||
r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
|
||||
im = (im * r + im2 * (1 - r)).astype(np.uint8)
|
||||
cat_instances = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0)
|
||||
cls = labels["cls"]
|
||||
labels["img"] = im
|
||||
labels["instances"] = cat_instances
|
||||
labels["cls"] = np.concatenate([cls, cls2], 0)
|
||||
labels2 = labels["mix_labels"][0]
|
||||
labels["img"] = (labels["img"] * r + labels2["img"] * (1 - r)).astype(np.uint8)
|
||||
labels["instances"] = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0)
|
||||
labels["cls"] = np.concatenate([labels["cls"], labels2["cls"]], 0)
|
||||
return labels
|
||||
|
||||
|
||||
@ -412,7 +394,6 @@ class RandomHSV:
|
||||
|
||||
im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
|
||||
cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=img) # no return needed
|
||||
labels["img"] = img
|
||||
return labels
|
||||
|
||||
|
||||
@ -606,7 +587,6 @@ class Format:
|
||||
self.batch_idx = batch_idx # keep the batch indexes
|
||||
|
||||
def __call__(self, labels):
|
||||
labels.pop("dataset", None)
|
||||
img = labels["img"]
|
||||
h, w = img.shape[:2]
|
||||
cls = labels.pop("cls")
|
||||
@ -656,9 +636,9 @@ class Format:
|
||||
return masks, instances, cls
|
||||
|
||||
|
||||
def mosaic_transforms(imgsz, hyp):
|
||||
def mosaic_transforms(dataset, imgsz, hyp):
|
||||
pre_transform = Compose([
|
||||
Mosaic(imgsz=imgsz, p=hyp.mosaic, border=[-imgsz // 2, -imgsz // 2]),
|
||||
Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic, border=[-imgsz // 2, -imgsz // 2]),
|
||||
CopyPaste(p=hyp.copy_paste),
|
||||
RandomPerspective(
|
||||
degrees=hyp.degrees,
|
||||
@ -670,7 +650,7 @@ def mosaic_transforms(imgsz, hyp):
|
||||
),])
|
||||
return Compose([
|
||||
pre_transform,
|
||||
MixUp(pre_transform=pre_transform, p=hyp.mixup),
|
||||
MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
|
||||
Albumentations(p=1.0),
|
||||
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
|
||||
RandomFlip(direction="vertical", p=hyp.flipud),
|
||||
|
@ -42,12 +42,13 @@ class BaseDataset(Dataset):
|
||||
self.imgsz = imgsz
|
||||
self.label_path = label_path
|
||||
self.augment = augment
|
||||
self.single_cls = single_cls
|
||||
self.prefix = prefix
|
||||
|
||||
self.im_files = self.get_img_files(self.img_path)
|
||||
self.labels = self.get_labels()
|
||||
if single_cls:
|
||||
self.update_labels(include_class=[], single_cls=single_cls)
|
||||
if self.single_cls:
|
||||
self.update_labels(include_class=[])
|
||||
|
||||
self.ni = len(self.im_files)
|
||||
|
||||
@ -173,10 +174,7 @@ class BaseDataset(Dataset):
|
||||
self.batch = bi # batch index of image
|
||||
|
||||
def __getitem__(self, index):
|
||||
label = self.get_label_info(index)
|
||||
if self.augment:
|
||||
label["dataset"] = self
|
||||
return self.transforms(label)
|
||||
return self.transforms(self.get_label_info(index))
|
||||
|
||||
def get_label_info(self, index):
|
||||
label = self.labels[index].copy()
|
||||
|
@ -1,7 +1,6 @@
|
||||
from itertools import repeat
|
||||
from multiprocessing.pool import Pool
|
||||
from pathlib import Path
|
||||
from typing import OrderedDict
|
||||
|
||||
import torchvision
|
||||
from tqdm import tqdm
|
||||
@ -126,7 +125,7 @@ class YOLODataset(BaseDataset):
|
||||
def build_transforms(self, hyp=None):
|
||||
if self.augment:
|
||||
mosaic = self.augment and not self.rect
|
||||
transforms = mosaic_transforms(self.imgsz, hyp) if mosaic else affine_transforms(self.imgsz, hyp)
|
||||
transforms = mosaic_transforms(self, self.imgsz, hyp) if mosaic else affine_transforms(self.imgsz, hyp)
|
||||
else:
|
||||
transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz))])
|
||||
transforms.append(
|
||||
|
@ -72,18 +72,12 @@ weight_decay: 0.0005 # optimizer weight decay 5e-4
|
||||
warmup_epochs: 3.0 # warmup epochs (fractions ok)
|
||||
warmup_momentum: 0.8 # warmup initial momentum
|
||||
warmup_bias_lr: 0.1 # warmup initial bias lr
|
||||
box: 0.05 # box loss gain
|
||||
cls: 0.5 # cls loss gain
|
||||
cls_pw: 1.0 # cls BCELoss positive_weight
|
||||
obj: 1.0 # obj loss gain (scale with pixels)
|
||||
obj_pw: 1.0 # obj BCELoss positive_weight
|
||||
iou_t: 0.20 # IoU training threshold
|
||||
anchor_t: 4.0 # anchor-multiple threshold
|
||||
# anchors: 3 # anchors per output layer (0 to ignore)
|
||||
box: 7.5 # box loss gain
|
||||
cls: 0.5 # cls loss gain (scale with pixels)
|
||||
dfl: 1.5 # dfl loss gain
|
||||
fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5)
|
||||
label_smoothing: 0.0
|
||||
nbs: 64 # nominal batch size
|
||||
# anchors: 3
|
||||
hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
|
||||
hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
|
||||
hsv_v: 0.4 # image HSV-Value augmentation (fraction)
|
||||
|
@ -30,8 +30,8 @@ class DetectionTrainer(BaseTrainer):
|
||||
def set_model_attributes(self):
|
||||
nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)
|
||||
self.args.box *= 3 / nl # scale to layers
|
||||
self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
|
||||
self.args.obj *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
|
||||
# self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
|
||||
self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
|
||||
self.model.nc = self.data["nc"] # attach number of classes to model
|
||||
self.model.args = self.args # attach hyperparameters to model
|
||||
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
|
||||
@ -85,14 +85,11 @@ class Loss:
|
||||
device = next(model.parameters()).device # get model device
|
||||
h = model.args # hyperparameters
|
||||
|
||||
# Define criteria
|
||||
BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h["cls_pw"]], device=device), reduction='none')
|
||||
|
||||
# Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
|
||||
self.cp, self.cn = smooth_BCE(eps=h.get("label_smoothing", 0.0)) # positive, negative BCE targets
|
||||
|
||||
m = model.model[-1] # Detect() module
|
||||
self.BCEcls = BCEcls
|
||||
self.bce = nn.BCEWithLogitsLoss(reduction='none')
|
||||
self.hyp = h
|
||||
self.stride = m.stride # model strides
|
||||
self.nc = m.nc # number of classes
|
||||
@ -156,7 +153,7 @@ class Loss:
|
||||
|
||||
# cls loss
|
||||
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
|
||||
loss[1] = self.BCEcls(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
|
||||
loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
|
||||
|
||||
# bbox loss
|
||||
if fg_mask.sum():
|
||||
|
Loading…
x
Reference in New Issue
Block a user