From 1197abeb1cee819bef11f34a0539decf019e9a45 Mon Sep 17 00:00:00 2001
From: wa22
Date: Thu, 23 May 2024 05:57:37 +0000
Subject: [PATCH] update

---
 ultralytics/__init__.py | 3 +-
 ultralytics/cfg/models/v10/yolov10b.yaml | 40 +++++++
 ultralytics/cfg/models/v10/yolov10l.yaml | 40 +++++++
 ultralytics/cfg/models/v10/yolov10m.yaml | 43 ++++++++
 ultralytics/cfg/models/v10/yolov10n.yaml | 40 +++++++
 ultralytics/cfg/models/v10/yolov10s.yaml | 39 +++++++
 ultralytics/cfg/models/v10/yolov10x.yaml | 40 +++++++
 ultralytics/engine/trainer.py | 3 +-
 ultralytics/engine/validator.py | 8 +-
 ultralytics/models/__init__.py | 3 +-
 ultralytics/models/yolov10/__init__.py | 5 +
 ultralytics/models/yolov10/model.py | 18 ++++
 ultralytics/models/yolov10/predict.py | 37 +++++++
 ultralytics/models/yolov10/train.py | 11 ++
 ultralytics/models/yolov10/val.py | 29 +++++
 ultralytics/nn/modules/__init__.py | 11 +-
 ultralytics/nn/modules/block.py | 129 +++++++++++++++++++++++
 ultralytics/nn/modules/head.py | 59 +++++++++--
 ultralytics/nn/tasks.py | 33 ++++--
 ultralytics/utils/loss.py | 16 ++-
 ultralytics/utils/tal.py | 3 +-
 ultralytics/utils/torch_utils.py | 9 +-
 22 files changed, 594 insertions(+), 25 deletions(-)
 create mode 100644 ultralytics/cfg/models/v10/yolov10b.yaml
 create mode 100644 ultralytics/cfg/models/v10/yolov10l.yaml
 create mode 100644 ultralytics/cfg/models/v10/yolov10m.yaml
 create mode 100644 ultralytics/cfg/models/v10/yolov10n.yaml
 create mode 100644 ultralytics/cfg/models/v10/yolov10s.yaml
 create mode 100644 ultralytics/cfg/models/v10/yolov10x.yaml
 create mode 100644 ultralytics/models/yolov10/__init__.py
 create mode 100644 ultralytics/models/yolov10/model.py
 create mode 100644 ultralytics/models/yolov10/predict.py
 create mode 100644 ultralytics/models/yolov10/train.py
 create mode 100644 ultralytics/models/yolov10/val.py

diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index d25836a0..8ff1b4fb 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -3,7 +3,7 @@
 __version__ = "8.1.34"

 from ultralytics.data.explorer.explorer import Explorer
-from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
+from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld, YOLOv10
 from ultralytics.models.fastsam import FastSAM
 from ultralytics.models.nas import NAS
 from ultralytics.utils import ASSETS, SETTINGS as settings
@@ -23,4 +23,5 @@ __all__ = (
     "download",
     "settings",
     "Explorer",
+    "YOLOv10"
 )
diff --git a/ultralytics/cfg/models/v10/yolov10b.yaml b/ultralytics/cfg/models/v10/yolov10b.yaml
new file mode 100644
index 00000000..a9dc7218
--- /dev/null
+++ b/ultralytics/cfg/models/v10/yolov10b.yaml
@@ -0,0 +1,40 @@
+# Parameters
+nc: 80 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
+  # [depth, width, max_channels]
+  b: [0.67, 1.00, 512]
+
+# YOLOv8.0n backbone
+backbone:
+  # [from, repeats, module, args]
+  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+  - [-1, 3, C2f, [128, True]]
+  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+  - [-1, 6, C2f, [256, True]]
+  - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
+  - [-1, 6, C2f, [512, True]]
+  - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
+  - [-1, 3, C2fCIB, [1024, True]]
+  - [-1, 1, SPPF, [1024, 5]] # 9
+  - [-1, 1, PSA, [1024]] # 10
+
+# YOLOv8.0n head
+head:
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2fCIB, [512, True]] # 13
+
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]] # 16 (P3/8-small)
+
+  - [-1, 1, Conv, [256, 3, 2]]
+  - [[-1, 13], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium)
+
+  - [-1, 1, SCDown, [512, 3, 2]]
+  - [[-1, 10], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large)
+
+  - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)
diff --git a/ultralytics/cfg/models/v10/yolov10l.yaml b/ultralytics/cfg/models/v10/yolov10l.yaml
new file mode 100644
index 00000000..047de262
--- /dev/null
+++ b/ultralytics/cfg/models/v10/yolov10l.yaml
@@ -0,0 +1,40 @@
+# Parameters
+nc: 80 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
+  # [depth, width, max_channels]
+  l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
+
+# YOLOv8.0n backbone
+backbone:
+  # [from, repeats, module, args]
+  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+  - [-1, 3, C2f, [128, True]]
+  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+  - [-1, 6, C2f, [256, True]]
+  - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
+  - [-1, 6, C2f, [512, True]]
+  - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
+  - [-1, 3, C2fCIB, [1024, True]]
+  - [-1, 1, SPPF, [1024, 5]] # 9
+  - [-1, 1, PSA, [1024]] # 10
+
+# YOLOv8.0n head
+head:
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2fCIB, [512, True]] # 13
+
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]] # 16 (P3/8-small)
+
+  - [-1, 1, Conv, [256, 3, 2]]
+  - [[-1, 13], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium)
+
+  - [-1, 1, SCDown, [512, 3, 2]]
+  - [[-1, 10], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large)
+
+  - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)
diff --git a/ultralytics/cfg/models/v10/yolov10m.yaml b/ultralytics/cfg/models/v10/yolov10m.yaml
new file mode 100644
index 00000000..5bdb5bf5
--- /dev/null
+++ b/ultralytics/cfg/models/v10/yolov10m.yaml
@@ -0,0 +1,43 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
+
+# Parameters
+nc: 80 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
+  # [depth, width, max_channels]
+  m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
+
+# YOLOv8.0n backbone
+backbone:
+  # [from, repeats, module, args]
+  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+  - [-1, 3, C2f, [128, True]]
+  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+  - [-1, 6, C2f, [256, True]]
+  - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
+  - [-1, 6, C2f, [512, True]]
+  - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
+  - [-1, 3, C2fCIB, [1024, True]]
+  - [-1, 1, SPPF, [1024, 5]] # 9
+  - [-1, 1, PSA, [1024]] # 10
+
+# YOLOv8.0n head
+head:
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2f, [512]] # 13
+
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]] # 16 (P3/8-small)
+
+  - [-1, 1, Conv, [256, 3, 2]]
+  - [[-1, 13], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium)
+
+  - [-1, 1, SCDown, [512, 3, 2]]
+  - [[-1, 10], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large)
+
+  - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)
diff --git a/ultralytics/cfg/models/v10/yolov10n.yaml b/ultralytics/cfg/models/v10/yolov10n.yaml
new file mode 100644
index 00000000..1ee7437e
--- /dev/null
+++ b/ultralytics/cfg/models/v10/yolov10n.yaml
@@ -0,0 +1,40 @@
+# Parameters
+nc: 80 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
+  # [depth, width, max_channels]
+  n: [0.33, 0.25, 1024]
+
+# YOLOv8.0n backbone
+backbone:
+  # [from, repeats, module, args]
+  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+  - [-1, 3, C2f, [128, True]]
+  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+  - [-1, 6, C2f, [256, True]]
+  - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
+  - [-1, 6, C2f, [512, True]]
+  - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
+  - [-1, 3, C2f, [1024, True]]
+  - [-1, 1, SPPF, [1024, 5]] # 9
+  - [-1, 1, PSA, [1024]] # 10
+
+# YOLOv8.0n head
+head:
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2f, [512]] # 13
+
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]] # 16 (P3/8-small)
+
+  - [-1, 1, Conv, [256, 3, 2]]
+  - [[-1, 13], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2f, [512]] # 19 (P4/16-medium)
+
+  - [-1, 1, SCDown, [512, 3, 2]]
+  - [[-1, 10], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2fCIB, [1024, True, True]] # 22 (P5/32-large)
+
+  - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)
diff --git a/ultralytics/cfg/models/v10/yolov10s.yaml b/ultralytics/cfg/models/v10/yolov10s.yaml
new file mode 100644
index 00000000..c61e08cd
--- /dev/null
+++ b/ultralytics/cfg/models/v10/yolov10s.yaml
@@ -0,0 +1,39 @@
+# Parameters
+nc: 80 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
+  # [depth, width, max_channels]
+  s: [0.33, 0.50, 1024]
+
+backbone:
+  # [from, repeats, module, args]
+  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+  - [-1, 3, C2f, [128, True]]
+  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+  - [-1, 6, C2f, [256, True]]
+  - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
+  - [-1, 6, C2f, [512, True]]
+  - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
+  - [-1, 3, C2fCIB, [1024, True, True]]
+  - [-1, 1, SPPF, [1024, 5]] # 9
+  - [-1, 1, PSA, [1024]] # 10
+
+# YOLOv8.0n head
+head:
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2f, [512]] # 13
+
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]] # 16 (P3/8-small)
+
+  - [-1, 1, Conv, [256, 3, 2]]
+  - [[-1, 13], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2f, [512]] # 19 (P4/16-medium)
+
+  - [-1, 1, SCDown, [512, 3, 2]]
+  - [[-1, 10], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2fCIB, [1024, True, True]] # 22 (P5/32-large)
+
+  - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)
diff --git a/ultralytics/cfg/models/v10/yolov10x.yaml b/ultralytics/cfg/models/v10/yolov10x.yaml
new file mode 100644
index 00000000..ab5fc8f0
--- /dev/null
+++ b/ultralytics/cfg/models/v10/yolov10x.yaml
@@ -0,0 +1,40 @@
+# Parameters
+nc: 80 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
+  # [depth, width, max_channels]
+  x: [1.00, 1.25, 512]
+
+# YOLOv8.0n backbone
+backbone:
+  # [from, repeats, module, args]
+  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
+  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
+  - [-1, 3, C2f, [128, True]]
+  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
+  - [-1, 6, C2f, [256, True]]
+  - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
+  - [-1, 6, C2fCIB, [512, True]]
+  - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
+  - [-1, 3, C2fCIB, [1024, True]]
+  - [-1, 1, SPPF, [1024, 5]] # 9
+  - [-1, 1, PSA, [1024]] # 10
+
+# YOLOv8.0n head
+head:
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
+  - [-1, 3, C2fCIB, [512, True]] # 13
+
+  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
+  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
+  - [-1, 3, C2f, [256]] # 16 (P3/8-small)
+
+  - [-1, 1, Conv, [256, 3, 2]]
+  - [[-1, 13], 1, Concat, [1]] # cat head P4
+  - [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium)
+
+  - [-1, 1, SCDown, [512, 3, 2]]
+  - [[-1, 10], 1, Concat, [1]] # cat head P5
+  - [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large)
+
+  - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)
diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py
index 29d6f1d0..841ec120 100644
--- a/ultralytics/engine/trainer.py
+++ b/ultralytics/engine/trainer.py
@@ -425,7 +425,8 @@ class BaseTrainer:
                 self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])

                 # Validation
-                if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
+                if (self.args.val and (((epoch+1) % 10 == 0) or (self.epochs - epoch) <= 10)) \
+                    or final_epoch or self.stopper.possible_stop or self.stop:
                     self.metrics, self.fitness = self.validate()
                 self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
                 self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch
diff --git a/ultralytics/engine/validator.py b/ultralytics/engine/validator.py
index 17666e38..aa329a41 100644
--- a/ultralytics/engine/validator.py
+++ b/ultralytics/engine/validator.py
@@ -196,10 +196,16 @@ class BaseValidator:
         self.check_stats(stats)
         self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
         self.finalize_metrics()
-        self.print_results()
+        # self.print_results()
         self.run_callbacks("on_val_end")
         if self.training:
             model.float()
+            assert(self.args.save_json and self.jdict)
+            with open(str(self.save_dir / "predictions.json"), "w") as f:
+                LOGGER.info(f"Saving {f.name}...")
+                json.dump(self.jdict, f)  # flatten and save
+            stats = self.eval_json(stats)  # update stats
+            stats['fitness'] = stats['metrics/mAP50-95(B)']
             results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
             return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
         else:
diff --git a/ultralytics/models/__init__.py b/ultralytics/models/__init__.py
index b9b6eb35..42de3fba 100644
--- a/ultralytics/models/__init__.py
+++ b/ultralytics/models/__init__.py
@@ -3,5 +3,6 @@
 from .rtdetr import RTDETR
 from .sam import SAM
 from .yolo import YOLO, YOLOWorld
+from .yolov10 import YOLOv10

-__all__ = "YOLO", "RTDETR", "SAM", "YOLOWorld"  # allow simpler import
+__all__ = "YOLO", "RTDETR", "SAM", "YOLOWorld", "YOLOv10"  # allow simpler import
diff --git a/ultralytics/models/yolov10/__init__.py b/ultralytics/models/yolov10/__init__.py
new file mode 100644
index 00000000..97f137f9
--- /dev/null
+++ b/ultralytics/models/yolov10/__init__.py
@@ -0,0 +1,5 @@
+from .model import YOLOv10
+from .predict import YOLOv10DetectionPredictor
+from .val import YOLOv10DetectionValidator
+
+__all__ = "YOLOv10DetectionPredictor", "YOLOv10DetectionValidator", "YOLOv10"
diff --git a/ultralytics/models/yolov10/model.py b/ultralytics/models/yolov10/model.py
new file mode 100644
index 00000000..e1c3e28c
--- /dev/null
+++ b/ultralytics/models/yolov10/model.py
@@ -0,0 +1,18 @@
+from ..yolo import YOLO
+from ultralytics.nn.tasks import YOLOv10DetectionModel
+from .val import YOLOv10DetectionValidator
+from .predict import YOLOv10DetectionPredictor
+from .train import YOLOv10DetectionTrainer
+
+class YOLOv10(YOLO):
+    @property
+    def task_map(self):
+        """Map head to model, trainer, validator, and predictor classes."""
+        return {
+            "detect": {
+                "model": YOLOv10DetectionModel,
+                "trainer": YOLOv10DetectionTrainer,
+                "validator": YOLOv10DetectionValidator,
+                "predictor": YOLOv10DetectionPredictor,
+            },
+        }
\ No newline at end of file
diff --git a/ultralytics/models/yolov10/predict.py b/ultralytics/models/yolov10/predict.py
new file mode 100644
index 00000000..ad1a7542
--- /dev/null
+++ b/ultralytics/models/yolov10/predict.py
@@ -0,0 +1,37 @@
+from ultralytics.models.yolo.detect import DetectionPredictor
+import torch
+from ultralytics.utils import ops
+from ultralytics.engine.results import Results
+
+
+class YOLOv10DetectionPredictor(DetectionPredictor):
+    def postprocess(self, preds, img, orig_imgs):
+        if not isinstance(preds, (list, tuple)):
+            preds = [preds, None]
+
+        prediction = preds[0].transpose(-1, -2)
+        _, _, nd = prediction.shape
+        nc = nd - 4
+        bboxes, scores = prediction.split((4, nd-4), dim=-1)
+        bboxes = ops.xywh2xyxy(bboxes)
+
+        scores, index = torch.topk(scores.flatten(1), self.args.max_det, axis=-1)
+        labels = index % nc
+        index = torch.div(index, nc, rounding_mode='floor')
+        bboxes = bboxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bboxes.shape[-1]))
+
+        preds = torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
+        assert(preds.shape[0] == 1)
+        mask = preds[..., 4] > self.args.conf
+        preds = preds[mask].unsqueeze(0)
+
+        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
+            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+
+        results = []
+        for i, pred in enumerate(preds):
+            orig_img = orig_imgs[i]
+            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+            img_path = self.batch[0][i]
+            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
+        return results
diff --git a/ultralytics/models/yolov10/train.py b/ultralytics/models/yolov10/train.py
new file mode 100644
index 00000000..66b8d71c
--- /dev/null
+++ b/ultralytics/models/yolov10/train.py
@@ -0,0 +1,11 @@
+from ultralytics.models.yolo.detect import DetectionTrainer
+from .val import YOLOv10DetectionValidator
+from copy import copy
+
+class YOLOv10DetectionTrainer(DetectionTrainer):
+    def get_validator(self):
+        """Returns a DetectionValidator for YOLO model validation."""
+        self.loss_names = "box_om", "cls_om", "dfl_om", "box_oo", "cls_oo", "dfl_oo",
+        return YOLOv10DetectionValidator(
+            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
+        )
diff --git a/ultralytics/models/yolov10/val.py b/ultralytics/models/yolov10/val.py
new file mode 100644
index 00000000..0993681c
--- /dev/null
+++ b/ultralytics/models/yolov10/val.py
@@ -0,0 +1,29 @@
+from ultralytics.models.yolo.detect import DetectionValidator
+from ultralytics.utils import ops
+import torch
+
+class YOLOv10DetectionValidator(DetectionValidator):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.args.save_json |= self.is_coco
+
+    def postprocess(self, preds):
+        if self.training:
+            preds = preds["one2one"]
+
+        if not isinstance(preds, (list, tuple)):
+            preds = [preds, None]
+
+        prediction = preds[0].transpose(-1, -2)
+        _, _, nd = prediction.shape
+        nc = nd - 4
+        assert(self.nc == nc)
+        bboxes, scores = prediction.split((4, nd-4), dim=-1)
+        bboxes = ops.xywh2xyxy(bboxes)
+
+        scores, index = torch.topk(scores.flatten(1), self.args.max_det, axis=-1)
+        labels = index % self.nc
+        index = torch.div(index, self.nc, rounding_mode='floor')
+        bboxes = bboxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bboxes.shape[-1]))
+
+        return torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
\ No newline at end of file
diff --git a/ultralytics/nn/modules/__init__.py b/ultralytics/nn/modules/__init__.py
index d785c008..4a99bf59 100644
--- a/ultralytics/nn/modules/__init__.py
+++ b/ultralytics/nn/modules/__init__.py
@@ -46,6 +46,10 @@ from .block import (
     CBFuse,
     CBLinear,
     Silence,
+    PSA,
+    C2fCIB,
+    SCDown,
+    RepVGGDW
 )
 from .conv import (
     CBAM,
@@ -62,7 +66,7 @@ from .conv import (
     RepConv,
     SpatialAttention,
 )
-from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment, WorldDetect
+from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment, WorldDetect, v10Detect
 from .transformer import (
     AIFI,
     MLP,
@@ -135,4 +139,9 @@ __all__ = (
     "CBFuse",
     "CBLinear",
     "Silence",
+    "PSA",
+    "C2fCIB",
+    "SCDown",
+    "RepVGGDW",
+    "v10Detect"
 )
diff --git a/ultralytics/nn/modules/block.py b/ultralytics/nn/modules/block.py
index a263b603..b76a9595 100644
--- a/ultralytics/nn/modules/block.py
+++ b/ultralytics/nn/modules/block.py
@@ -7,6 +7,7 @@ import torch.nn.functional as F

 from .conv import Conv, DWConv, GhostConv, LightConv, RepConv, autopad
 from .transformer import TransformerBlock
+from ultralytics.utils.torch_utils import fuse_conv_and_bn

 __all__ = (
     "DFL",
@@ -696,3 +697,131 @@ class CBFuse(nn.Module):
         res = [F.interpolate(x[self.idx[i]], size=target_size, mode="nearest") for i, x in enumerate(xs[:-1])]
         out = torch.sum(torch.stack(res + xs[-1:]), dim=0)
         return out
+
+
+class RepVGGDW(torch.nn.Module):
+    def __init__(self, ed) -> None:
+        super().__init__()
+        self.conv = Conv(ed, ed, 7, 1, 3, g=ed, act=False)
+        self.conv1 = Conv(ed, ed, 3, 1, 1, g=ed, act=False)
+        self.dim = ed
+        self.act = nn.SiLU()
+
+    def forward(self, x):
+        return self.act(self.conv(x) + self.conv1(x))
+
+    def forward_fuse(self, x):
+        return self.act(self.conv(x))
+
+    @torch.no_grad()
+    def fuse(self):
+        conv = fuse_conv_and_bn(self.conv.conv, self.conv.bn)
+        conv1 = fuse_conv_and_bn(self.conv1.conv, self.conv1.bn)
+
+        conv_w = conv.weight
+        conv_b = conv.bias
+        conv1_w = conv1.weight
+        conv1_b = conv1.bias
+
+        conv1_w = torch.nn.functional.pad(conv1_w, [2,2,2,2])
+
+        final_conv_w = conv_w + conv1_w
+        final_conv_b = conv_b + conv1_b
+
+        conv.weight.data.copy_(final_conv_w)
+        conv.bias.data.copy_(final_conv_b)
+
+        self.conv = conv
+        del self.conv1
+
+class CIB(nn.Module):
+    """Standard bottleneck."""
+
+    def __init__(self, c1, c2, shortcut=True, e=0.5, lk=False):
+        """Initializes a bottleneck module with given input/output channels, shortcut option, group, kernels, and
+        expansion.
+        """
+        super().__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = nn.Sequential(
+            Conv(c1, c1, 3, g=c1),
+            Conv(c1, 2 * c_, 1),
+            Conv(2 * c_, 2 * c_, 3, g=2 * c_) if not lk else RepVGGDW(2 * c_),
+            Conv(2 * c_, c2, 1),
+            Conv(c2, c2, 3, g=c2),
+        )
+
+        self.add = shortcut and c1 == c2
+
+    def forward(self, x):
+        """'forward()' applies the YOLO FPN to input data."""
+        return x + self.cv1(x) if self.add else self.cv1(x)
+
+class C2fCIB(C2f):
+    """Faster Implementation of CSP Bottleneck with 2 convolutions."""
+
+    def __init__(self, c1, c2, n=1, shortcut=False, lk=False, g=1, e=0.5):
+        """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,
+        expansion.
+        """
+        super().__init__(c1, c2, n, shortcut, g, e)
+        self.m = nn.ModuleList(CIB(self.c, self.c, shortcut, e=1.0, lk=lk) for _ in range(n))
+
+
+class Attention(nn.Module):
+    def __init__(self, dim, num_heads=8,
+                 attn_ratio=0.5):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.key_dim = int(self.head_dim * attn_ratio)
+        self.scale = self.key_dim ** -0.5
+        nh_kd = nh_kd = self.key_dim * num_heads
+        h = dim + nh_kd * 2
+        self.qkv = Conv(dim, h, 1, act=False)
+        self.proj = Conv(dim, dim, 1, act=False)
+        self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)
+
+    def forward(self, x):
+        B, _, H, W = x.shape
+        N = H * W
+        qkv = self.qkv(x)
+        q, k, v = qkv.view(B, self.num_heads, -1, N).split([self.key_dim, self.key_dim, self.head_dim], dim=2)
+
+        attn = (
+            (q.transpose(-2, -1) @ k) * self.scale
+        )
+        attn = attn.softmax(dim=-1)
+        x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) + self.pe(v.reshape(B, -1, H, W))
+        x = self.proj(x)
+        return x
+
+class PSA(nn.Module):
+
+    def __init__(self, c1, c2, e=0.5):
+        super().__init__()
+        assert(c1 == c2)
+        self.c = int(c1 * e)
+        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
+        self.cv2 = Conv(2 * self.c, c1, 1)
+
+        self.attn = Attention(self.c, attn_ratio=0.5, num_heads=self.c // 64)
+        self.ffn = nn.Sequential(
+            Conv(self.c, self.c*2, 1),
+            Conv(self.c*2, self.c, 1, act=False)
+        )
+
+    def forward(self, x):
+        a, b = self.cv1(x).split((self.c, self.c), dim=1)
+        b = b + self.attn(b)
+        b = b + self.ffn(b)
+        return self.cv2(torch.cat((a, b), 1))
+
+class SCDown(nn.Module):
+    def __init__(self, c1, c2, k, s):
+        super().__init__()
+        self.cv1 = Conv(c1, c2, 1, 1)
+        self.cv2 = Conv(c2, c2, k=k, s=s, g=c2, act=False)
+
+    def forward(self, x):
+        return self.cv2(self.cv1(x))
\ No newline at end of file
diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py
index 9cd794e4..a94ceedb 100644
--- a/ultralytics/nn/modules/head.py
+++ b/ultralytics/nn/modules/head.py
@@ -12,6 +12,7 @@ from .block import DFL, Proto, ContrastiveHead, BNContrastiveHead
 from .conv import Conv
 from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
 from .utils import bias_init_with_prob, linear_init
+import copy

 __all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder"

@@ -40,17 +41,17 @@ class Detect(nn.Module):
         self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
         self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

-    def forward(self, x):
-        """Concatenates and returns predicted bounding boxes and class probabilities."""
-        for i in range(self.nl):
-            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
-        if self.training:  # Training path
-            return x
+    def generate_static_anchors(self, x):
+        shape = x[0].shape
+        self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
+        self.shape = shape

+    def inference(self, x):
         # Inference path
         shape = x[0].shape  # BCHW
         x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
         if self.dynamic or self.shape != shape:
+            assert(not self.export)
             self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
             self.shape = shape

@@ -74,6 +75,21 @@ class Detect(nn.Module):
         y = torch.cat((dbox, cls.sigmoid()), 1)
         return y if self.export else (y, x)

+    def forward_feat(self, x, cv2, cv3):
+        y = []
+        for i in range(self.nl):
+            y.append(torch.cat((cv2[i](x[i]), cv3[i](x[i])), 1))
+        return y
+
+    def forward(self, x):
+        """Concatenates and returns predicted bounding boxes and class probabilities."""
+        y = self.forward_feat(x, self.cv2, self.cv3)
+
+        if self.training:
+            return y
+
+        return self.inference(y)
+
     def bias_init(self):
         """Initialize Detect() biases, WARNING: requires stride availability."""
         m = self  # self.model[-1]  # Detect() module
@@ -480,3 +496,34 @@ class RTDETRDecoder(nn.Module):
             xavier_uniform_(self.query_pos_head.layers[1].weight)
         for layer in self.input_proj:
             xavier_uniform_(layer[0].weight)
+
+class v10Detect(Detect):
+
+    def __init__(self, nc=80, ch=()):
+        super().__init__(nc, ch)
+        c3 = max(ch[0], min(self.nc, 100))  # channels
+        self.cv3 = nn.ModuleList(nn.Sequential(nn.Sequential(Conv(x, x, 3, g=x), Conv(x, c3, 1)), \
+                                               nn.Sequential(Conv(c3, c3, 3, g=c3), Conv(c3, c3, 1)), \
+                                               nn.Conv2d(c3, self.nc, 1)) for i, x in enumerate(ch))
+
+        self.one2one_cv2 = copy.deepcopy(self.cv2)
+        self.one2one_cv3 = copy.deepcopy(self.cv3)
+
+    def forward(self, x):
+        one2one = self.forward_feat([xi.detach() for xi in x], self.one2one_cv2, self.one2one_cv3)
+        if not self.training:
+            one2one = self.inference(one2one)
+            return one2one
+        else:
+            one2many = super().forward(x)
+            return {"one2many": one2many, "one2one": one2one}
+
+    def bias_init(self):
+        super().bias_init()
+        """Initialize Detect() biases, WARNING: requires stride availability."""
+        m = self  # self.model[-1]  # Detect() module
+        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
+        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
+        for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride):  # from
+            a[-1].bias.data[:] = 1.0  # box
+            b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index f116ed2c..99b21525 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -49,10 +49,15 @@ from ultralytics.nn.modules import (
     CBFuse,
     CBLinear,
     Silence,
+    C2fCIB,
+    PSA,
+    SCDown,
+    RepVGGDW,
+    v10Detect
 )
 from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
 from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
-from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss
+from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss, v10DetectLoss
 from ultralytics.utils.plotting import feature_visualization
 from ultralytics.utils.torch_utils import (
     fuse_conv_and_bn,
@@ -191,6 +196,9 @@ class BaseModel(nn.Module):
                 if isinstance(m, RepConv):
                     m.fuse_convs()
                     m.forward = m.forward_fuse  # update forward
+                if isinstance(m, RepVGGDW):
+                    m.fuse()
+                    m.forward = m.forward_fuse
             self.info(verbose=verbose)

         return self
@@ -294,6 +302,8 @@ class DetectionModel(BaseModel):
             s = 256  # 2x min stride
             m.inplace = self.inplace
             forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
+            if isinstance(m, v10Detect):
+                forward = lambda x: self.forward(x)["one2many"]
             m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))])  # forward
             self.stride = m.stride
             m.bias_init()  # only run once
@@ -627,6 +637,9 @@ class WorldModel(DetectionModel):
                 return torch.unbind(torch.cat(embeddings, 1), dim=0)
         return x

+class YOLOv10DetectionModel(DetectionModel):
+    def init_criterion(self):
+        return v10DetectLoss(self)

 class Ensemble(nn.ModuleList):
     """Ensemble of models."""
@@ -869,6 +882,9 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             DWConvTranspose2d,
             C3x,
             RepC3,
+            PSA,
+            SCDown,
+            C2fCIB
         }:
             c1, c2 = ch[f], args[0]
             if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
@@ -880,7 +896,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
                 )  # num heads

             args = [c1, c2, *args[1:]]
-            if m in (BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3):
+            if m in (BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3, C2fCIB):
                 args.insert(2, n)  # number of repeats
                 n = 1
         elif m is AIFI:
@@ -897,7 +913,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             args = [ch[f]]
         elif m is Concat:
             c2 = sum(ch[x] for x in f)
-        elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn}:
+        elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}:
             args.append([ch[x] for x in f])
             if m is Segment:
                 args[2] = make_divisible(min(args[2], max_channels) * width, 8)
@@ -936,7 +952,10 @@ def yaml_model_load(path):
         LOGGER.warning(f"WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.")
         path = path.with_name(new_stem + path.suffix)

-    unified_path = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", str(path))  # i.e. yolov8x.yaml -> yolov8.yaml
+    if "v10" not in str(path):
+        unified_path = re.sub(r"(\d+)([nsblmx])(.+)?$", r"\1\3", str(path))  # i.e. yolov8x.yaml -> yolov8.yaml
+    else:
+        unified_path = path
     yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
     d = yaml_load(yaml_file)  # model dict
     d["scale"] = guess_model_scale(path)
@@ -959,7 +978,7 @@ def guess_model_scale(model_path):
     with contextlib.suppress(AttributeError):
         import re

-        return re.search(r"yolov\d+([nslmx])", Path(model_path).stem).group(1)  # n, s, m, l, or x
+        return re.search(r"yolov\d+([nsblmx])", Path(model_path).stem).group(1)  # n, s, m, l, or x
     return ""

@@ -982,7 +1001,7 @@ def guess_model_task(model):
             m = cfg["head"][-1][-2].lower()  # output module name
             if m in {"classify", "classifier", "cls", "fc"}:
                 return "classify"
-            if m == "detect":
+            if m == "detect" or m == "v10detect":
                 return "detect"
             if m == "segment":
                 return "segment"
@@ -1014,7 +1033,7 @@ def guess_model_task(model):
                 return "pose"
             elif isinstance(m, OBB):
                 return "obb"
-            elif isinstance(m, (Detect, WorldDetect)):
+            elif isinstance(m, (Detect, WorldDetect, v10Detect)):
                 return "detect"

     # Guess from model filename
diff --git a/ultralytics/utils/loss.py b/ultralytics/utils/loss.py
index 360a292a..d0ca9c39 100644
--- a/ultralytics/utils/loss.py
+++ b/ultralytics/utils/loss.py
@@ -147,7 +147,7 @@ class KeypointLoss(nn.Module):
 class v8DetectionLoss:
     """Criterion class for computing training losses."""

-    def __init__(self, model):  # model must be de-paralleled
+    def __init__(self, model, tal_topk=10):  # model must be de-paralleled
         """Initializes v8DetectionLoss with the model, defining model-related properties and BCE loss function."""
         device = next(model.parameters()).device  # get model device
         h = model.args  # hyperparameters
@@ -163,7 +163,7 @@ class v8DetectionLoss:

         self.use_dfl = m.reg_max > 1

-        self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
+        self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
         self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device)
         self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)

@@ -713,3 +713,15 @@ class v8OBBLoss(v8DetectionLoss):
         b, a, c = pred_dist.shape  # batch, anchors, channels
         pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
         return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
+
+class v10DetectLoss:
+    def __init__(self, model):
+        self.one2many = v8DetectionLoss(model, tal_topk=10)
+        self.one2one = v8DetectionLoss(model, tal_topk=1)
+
+    def __call__(self, preds, batch):
+        one2many = preds["one2many"]
+        loss_one2many = self.one2many(one2many, batch)
+        one2one = preds["one2one"]
+        loss_one2one = self.one2one(one2one, batch)
+        return loss_one2many[0] + loss_one2one[0], torch.cat((loss_one2many[1], loss_one2one[1]))
diff --git a/ultralytics/utils/tal.py b/ultralytics/utils/tal.py
index 9cee0500..b11c2b2c 100644
--- a/ultralytics/utils/tal.py
+++ b/ultralytics/utils/tal.py
@@ -308,7 +308,8 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):

 def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
     """Transform distance(ltrb) to box(xywh or xyxy)."""
-    lt, rb = distance.chunk(2, dim)
+    assert(distance.shape[dim] == 4)
+    lt, rb = distance.split([2, 2], dim)
     x1y1 = anchor_points - lt
     x2y2 = anchor_points + rb
     if xywh:
diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py
index 77d8cc8c..d476e1f8 100644
--- a/ultralytics/utils/torch_utils.py
+++ b/ultralytics/utils/torch_utils.py
@@ -310,10 +310,11 @@ def get_flops(model, imgsz=640):
             imgsz = [imgsz, imgsz]  # expand if int/float
         try:
             # Use stride size for input tensor
-            stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32  # max stride
-            im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
-            flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # stride GFLOPs
-            return flops * imgsz[0] / stride * imgsz[1] / stride  # imgsz GFLOPs
+            # stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32  # max stride
+            # im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
+            # flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # stride GFLOPs
+            # return flops * imgsz[0] / stride * imgsz[1] / stride  # imgsz GFLOPs
+            raise Exception
         except Exception:
             # Use actual image size for input tensor (i.e. required for RTDETR models)
             im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # input image in BCHW format