wa22 2024-05-23 05:57:37 +00:00
parent 5be2ffbd13
commit 1197abeb1c
22 changed files with 594 additions and 25 deletions

View File

@@ -3,7 +3,7 @@
__version__ = "8.1.34"

from ultralytics.data.explorer.explorer import Explorer
-from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
+from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld, YOLOv10
from ultralytics.models.fastsam import FastSAM
from ultralytics.models.nas import NAS
from ultralytics.utils import ASSETS, SETTINGS as settings
@@ -23,4 +23,5 @@ __all__ = (
    "download",
    "settings",
    "Explorer",
+    "YOLOv10"
)

View File

@@ -0,0 +1,40 @@
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  b: [0.67, 1.00, 512]
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2fCIB, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 9
- [-1, 1, PSA, [1024]] # 10
# YOLOv8.0n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2fCIB, [512, True]] # 13
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 16 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 13], 1, Concat, [1]] # cat head P4
- [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium)
- [-1, 1, SCDown, [512, 3, 2]]
- [[-1, 10], 1, Concat, [1]] # cat head P5
- [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large)
- [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)
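Editorial note (not part of the commit): each scale entry above is read as [depth, width, max_channels] by parse_model in the tasks.py hunks further down — layer repeat counts are multiplied by depth, channel arguments by width and capped at max_channels. A rough worked example for the 'b' scale, under the usual Ultralytics scaling rules:

# For the backbone layer `- [-1, 6, C2f, [512, True]]` under b: [0.67, 1.00, 512]:
#   repeats:  max(round(6 * 0.67), 1) = 4
#   channels: make_divisible(min(512, 512) * 1.00, 8) = 512
# For `- [-1, 3, C2fCIB, [1024, True]]` the 1024 is first capped at max_channels = 512.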

View File

@@ -0,0 +1,40 @@
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2fCIB, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 9
- [-1, 1, PSA, [1024]] # 10
# YOLOv8.0n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2fCIB, [512, True]] # 13
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 16 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 13], 1, Concat, [1]] # cat head P4
- [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium)
- [-1, 1, SCDown, [512, 3, 2]]
- [[-1, 10], 1, Concat, [1]] # cat head P5
- [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large)
- [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)

View File

@@ -0,0 +1,43 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2fCIB, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 9
- [-1, 1, PSA, [1024]] # 10
# YOLOv8.0n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2f, [512]] # 13
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 16 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 13], 1, Concat, [1]] # cat head P4
- [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium)
- [-1, 1, SCDown, [512, 3, 2]]
- [[-1, 10], 1, Concat, [1]] # cat head P5
- [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large)
- [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)

View File

@@ -0,0 +1,40 @@
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 9
- [-1, 1, PSA, [1024]] # 10
# YOLOv8.0n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2f, [512]] # 13
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 16 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 13], 1, Concat, [1]] # cat head P4
- [-1, 3, C2f, [512]] # 19 (P4/16-medium)
- [-1, 1, SCDown, [512, 3, 2]]
- [[-1, 10], 1, Concat, [1]] # cat head P5
- [-1, 3, C2fCIB, [1024, True, True]] # 22 (P5/32-large)
- [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)

View File

@@ -0,0 +1,39 @@
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  s: [0.33, 0.50, 1024]
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2fCIB, [1024, True, True]]
- [-1, 1, SPPF, [1024, 5]] # 9
- [-1, 1, PSA, [1024]] # 10
# YOLOv8.0n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2f, [512]] # 13
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 16 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 13], 1, Concat, [1]] # cat head P4
- [-1, 3, C2f, [512]] # 19 (P4/16-medium)
- [-1, 1, SCDown, [512, 3, 2]]
- [[-1, 10], 1, Concat, [1]] # cat head P5
- [-1, 3, C2fCIB, [1024, True, True]] # 22 (P5/32-large)
- [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)

View File

@@ -0,0 +1,40 @@
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  x: [1.00, 1.25, 512]
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2fCIB, [512, True]]
- [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2fCIB, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 9
- [-1, 1, PSA, [1024]] # 10
# YOLOv8.0n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2fCIB, [512, True]] # 13
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 16 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 13], 1, Concat, [1]] # cat head P4
- [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium)
- [-1, 1, SCDown, [512, 3, 2]]
- [[-1, 10], 1, Concat, [1]] # cat head P5
- [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large)
- [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)

View File

@@ -425,7 +425,8 @@ class BaseTrainer:
                self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])

                # Validation
-                if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
+                if (self.args.val and (((epoch+1) % 10 == 0) or (self.epochs - epoch) <= 10)) \
+                        or final_epoch or self.stopper.possible_stop or self.stop:
                    self.metrics, self.fitness = self.validate()
                self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
                self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch
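Editorial note on the hunk above (not from the commit message): with the new condition, in-training validation runs only every 10th epoch and during the final 10 epochs, unless the final epoch, an imminent early stop, or an explicit stop request forces it; self.args.val must still be enabled for the periodic checks.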

View File

@@ -196,10 +196,16 @@ class BaseValidator:
        self.check_stats(stats)
        self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
        self.finalize_metrics()
-        self.print_results()
+        # self.print_results()
        self.run_callbacks("on_val_end")
        if self.training:
            model.float()
+            assert(self.args.save_json and self.jdict)
+            with open(str(self.save_dir / "predictions.json"), "w") as f:
+                LOGGER.info(f"Saving {f.name}...")
+                json.dump(self.jdict, f)  # flatten and save
+            stats = self.eval_json(stats)  # update stats
+            stats['fitness'] = stats['metrics/mAP50-95(B)']
            results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
            return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
        else:

View File

@@ -3,5 +3,6 @@
from .rtdetr import RTDETR
from .sam import SAM
from .yolo import YOLO, YOLOWorld
+from .yolov10 import YOLOv10

-__all__ = "YOLO", "RTDETR", "SAM", "YOLOWorld"  # allow simpler import
+__all__ = "YOLO", "RTDETR", "SAM", "YOLOWorld", "YOLOv10"  # allow simpler import

View File

@@ -0,0 +1,5 @@
from .model import YOLOv10
from .predict import YOLOv10DetectionPredictor
from .val import YOLOv10DetectionValidator
__all__ = "YOLOv10DetectionPredictor", "YOLOv10DetectionValidator", "YOLOv10"

View File

@@ -0,0 +1,18 @@
from ..yolo import YOLO
from ultralytics.nn.tasks import YOLOv10DetectionModel
from .val import YOLOv10DetectionValidator
from .predict import YOLOv10DetectionPredictor
from .train import YOLOv10DetectionTrainer


class YOLOv10(YOLO):

    @property
    def task_map(self):
        """Map head to model, trainer, validator, and predictor classes."""
        return {
            "detect": {
                "model": YOLOv10DetectionModel,
                "trainer": YOLOv10DetectionTrainer,
                "validator": YOLOv10DetectionValidator,
                "predictor": YOLOv10DetectionPredictor,
            },
        }
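For orientation, a minimal usage sketch of the new entry point (editorial example, not part of the commit; assumes the yolov10n.yaml config added above resolves by name and a dataset YAML such as coco128.yaml is available):

from ultralytics import YOLOv10

model = YOLOv10("yolov10n.yaml")                       # builds YOLOv10DetectionModel via the task map above
model.train(data="coco128.yaml", epochs=3, imgsz=640)  # routed to YOLOv10DetectionTrainer
results = model("path/to/image.jpg")                   # routed to YOLOv10DetectionPredictor (NMS-free postprocess)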

View File

@@ -0,0 +1,37 @@
from ultralytics.models.yolo.detect import DetectionPredictor
import torch
from ultralytics.utils import ops
from ultralytics.engine.results import Results


class YOLOv10DetectionPredictor(DetectionPredictor):
    def postprocess(self, preds, img, orig_imgs):
        if not isinstance(preds, (list, tuple)):
            preds = [preds, None]

        prediction = preds[0].transpose(-1, -2)
        _, _, nd = prediction.shape
        nc = nd - 4
        bboxes, scores = prediction.split((4, nd - 4), dim=-1)
        bboxes = ops.xywh2xyxy(bboxes)
        scores, index = torch.topk(scores.flatten(1), self.args.max_det, axis=-1)
        labels = index % nc
        index = torch.div(index, nc, rounding_mode='floor')
        bboxes = bboxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bboxes.shape[-1]))
        preds = torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)

        assert(preds.shape[0] == 1)
        mask = preds[..., 4] > self.args.conf
        preds = preds[mask].unsqueeze(0)

        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        results = []
        for i, pred in enumerate(preds):
            orig_img = orig_imgs[i]
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
            img_path = self.batch[0][i]
            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
        return results
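The flattened top-k over the (anchors × classes) score matrix is what replaces NMS here: each flat index encodes both an anchor row (index // nc) and a class (index % nc). A tiny standalone illustration (editorial, not from the commit), with 3 anchors and 2 classes:

import torch

scores = torch.tensor([[[0.1, 0.9], [0.8, 0.2], [0.3, 0.4]]])  # (batch=1, anchors=3, nc=2)
conf, idx = torch.topk(scores.flatten(1), k=2, dim=-1)         # flat indices into anchors * nc
anchor = torch.div(idx, 2, rounding_mode="floor")              # tensor([[0, 1]]) -> which anchor each hit came from
label = idx % 2                                                # tensor([[1, 0]]) -> which class each hit came from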

View File

@@ -0,0 +1,11 @@
from ultralytics.models.yolo.detect import DetectionTrainer
from .val import YOLOv10DetectionValidator
from copy import copy


class YOLOv10DetectionTrainer(DetectionTrainer):
    def get_validator(self):
        """Returns a DetectionValidator for YOLO model validation."""
        self.loss_names = "box_om", "cls_om", "dfl_om", "box_oo", "cls_oo", "dfl_oo"
        return YOLOv10DetectionValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

View File

@@ -0,0 +1,29 @@
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import ops
import torch


class YOLOv10DetectionValidator(DetectionValidator):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.args.save_json |= self.is_coco

    def postprocess(self, preds):
        if self.training:
            preds = preds["one2one"]

        if not isinstance(preds, (list, tuple)):
            preds = [preds, None]

        prediction = preds[0].transpose(-1, -2)
        _, _, nd = prediction.shape
        nc = nd - 4
        assert(self.nc == nc)
        bboxes, scores = prediction.split((4, nd - 4), dim=-1)
        bboxes = ops.xywh2xyxy(bboxes)
        scores, index = torch.topk(scores.flatten(1), self.args.max_det, axis=-1)
        labels = index % self.nc
        index = torch.div(index, self.nc, rounding_mode='floor')
        bboxes = bboxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bboxes.shape[-1]))

        return torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)

View File

@@ -46,6 +46,10 @@ from .block import (
    CBFuse,
    CBLinear,
    Silence,
+    PSA,
+    C2fCIB,
+    SCDown,
+    RepVGGDW
)
from .conv import (
    CBAM,
@@ -62,7 +66,7 @@ from .conv import (
    RepConv,
    SpatialAttention,
)
-from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment, WorldDetect
+from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment, WorldDetect, v10Detect
from .transformer import (
    AIFI,
    MLP,
@@ -135,4 +139,9 @@ __all__ = (
    "CBFuse",
    "CBLinear",
    "Silence",
+    "PSA",
+    "C2fCIB",
+    "SCDown",
+    "RepVGGDW",
+    "v10Detect"
)

View File

@@ -7,6 +7,7 @@ import torch.nn.functional as F
from .conv import Conv, DWConv, GhostConv, LightConv, RepConv, autopad
from .transformer import TransformerBlock
+from ultralytics.utils.torch_utils import fuse_conv_and_bn

__all__ = (
    "DFL",
@@ -696,3 +697,131 @@ class CBFuse(nn.Module):
        res = [F.interpolate(x[self.idx[i]], size=target_size, mode="nearest") for i, x in enumerate(xs[:-1])]
        out = torch.sum(torch.stack(res + xs[-1:]), dim=0)
        return out
+
+
+class RepVGGDW(torch.nn.Module):
+    def __init__(self, ed) -> None:
+        super().__init__()
+        self.conv = Conv(ed, ed, 7, 1, 3, g=ed, act=False)
+        self.conv1 = Conv(ed, ed, 3, 1, 1, g=ed, act=False)
+        self.dim = ed
+        self.act = nn.SiLU()
+
+    def forward(self, x):
+        return self.act(self.conv(x) + self.conv1(x))
+
+    def forward_fuse(self, x):
+        return self.act(self.conv(x))
+
+    @torch.no_grad()
+    def fuse(self):
+        conv = fuse_conv_and_bn(self.conv.conv, self.conv.bn)
+        conv1 = fuse_conv_and_bn(self.conv1.conv, self.conv1.bn)
+
+        conv_w = conv.weight
+        conv_b = conv.bias
+        conv1_w = conv1.weight
+        conv1_b = conv1.bias
+
+        conv1_w = torch.nn.functional.pad(conv1_w, [2, 2, 2, 2])
+
+        final_conv_w = conv_w + conv1_w
+        final_conv_b = conv_b + conv1_b
+
+        conv.weight.data.copy_(final_conv_w)
+        conv.bias.data.copy_(final_conv_b)
+
+        self.conv = conv
+        del self.conv1
+
+
+class CIB(nn.Module):
+    """Standard bottleneck."""
+
+    def __init__(self, c1, c2, shortcut=True, e=0.5, lk=False):
+        """Initializes a bottleneck module with given input/output channels, shortcut option, group, kernels, and
+        expansion.
+        """
+        super().__init__()
+        c_ = int(c2 * e)  # hidden channels
+        self.cv1 = nn.Sequential(
+            Conv(c1, c1, 3, g=c1),
+            Conv(c1, 2 * c_, 1),
+            Conv(2 * c_, 2 * c_, 3, g=2 * c_) if not lk else RepVGGDW(2 * c_),
+            Conv(2 * c_, c2, 1),
+            Conv(c2, c2, 3, g=c2),
+        )
+
+        self.add = shortcut and c1 == c2
+
+    def forward(self, x):
+        """'forward()' applies the YOLO FPN to input data."""
+        return x + self.cv1(x) if self.add else self.cv1(x)
+
+
+class C2fCIB(C2f):
+    """Faster Implementation of CSP Bottleneck with 2 convolutions."""
+
+    def __init__(self, c1, c2, n=1, shortcut=False, lk=False, g=1, e=0.5):
+        """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,
+        expansion.
+        """
+        super().__init__(c1, c2, n, shortcut, g, e)
+        self.m = nn.ModuleList(CIB(self.c, self.c, shortcut, e=1.0, lk=lk) for _ in range(n))
+
+
+class Attention(nn.Module):
+    def __init__(self, dim, num_heads=8, attn_ratio=0.5):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.key_dim = int(self.head_dim * attn_ratio)
+        self.scale = self.key_dim ** -0.5
+        nh_kd = self.key_dim * num_heads
+        h = dim + nh_kd * 2
+        self.qkv = Conv(dim, h, 1, act=False)
+        self.proj = Conv(dim, dim, 1, act=False)
+        self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)
+
+    def forward(self, x):
+        B, _, H, W = x.shape
+        N = H * W
+        qkv = self.qkv(x)
+        q, k, v = qkv.view(B, self.num_heads, -1, N).split([self.key_dim, self.key_dim, self.head_dim], dim=2)
+
+        attn = (q.transpose(-2, -1) @ k) * self.scale
+        attn = attn.softmax(dim=-1)
+        x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) + self.pe(v.reshape(B, -1, H, W))
+        x = self.proj(x)
+        return x
+
+
+class PSA(nn.Module):
+    def __init__(self, c1, c2, e=0.5):
+        super().__init__()
+        assert(c1 == c2)
+        self.c = int(c1 * e)
+        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
+        self.cv2 = Conv(2 * self.c, c1, 1)
+
+        self.attn = Attention(self.c, attn_ratio=0.5, num_heads=self.c // 64)
+        self.ffn = nn.Sequential(
+            Conv(self.c, self.c * 2, 1),
+            Conv(self.c * 2, self.c, 1, act=False)
+        )
+
+    def forward(self, x):
+        a, b = self.cv1(x).split((self.c, self.c), dim=1)
+        b = b + self.attn(b)
+        b = b + self.ffn(b)
+        return self.cv2(torch.cat((a, b), 1))
+
+
+class SCDown(nn.Module):
+    def __init__(self, c1, c2, k, s):
+        super().__init__()
+        self.cv1 = Conv(c1, c2, 1, 1)
+        self.cv2 = Conv(c2, c2, k=k, s=s, g=c2, act=False)
+
+    def forward(self, x):
+        return self.cv2(self.cv1(x))
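RepVGGDW.fuse() above folds the 3x3 depthwise branch into the 7x7 branch by zero-padding its kernel to 7x7 and summing the BN-fused weights and biases, so the single-conv path should reproduce the two-branch output. A rough sanity-check sketch (editorial, not part of the commit; assumes RepVGGDW is importable from ultralytics.nn.modules.block after this change):

import torch

m = RepVGGDW(16).eval()        # eval() so the Conv BatchNorms use running stats
x = torch.randn(1, 16, 32, 32)
y_two_branch = m(x)            # SiLU(conv7x7(x) + conv3x3(x))
m.fuse()                       # folds conv1 into the zero-padded 7x7 kernel
y_fused = m.forward_fuse(x)    # SiLU(fused_conv(x))
assert torch.allclose(y_two_branch, y_fused, atol=1e-5)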

View File

@@ -12,6 +12,7 @@ from .block import DFL, Proto, ContrastiveHead, BNContrastiveHead
from .conv import Conv
from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
from .utils import bias_init_with_prob, linear_init
+import copy

__all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder"
@@ -40,17 +41,17 @@ class Detect(nn.Module):
        self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

-    def forward(self, x):
-        """Concatenates and returns predicted bounding boxes and class probabilities."""
-        for i in range(self.nl):
-            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
-        if self.training:  # Training path
-            return x
+    def generate_static_anchors(self, x):
+        shape = x[0].shape
+        self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
+        self.shape = shape

+    def inference(self, x):
        # Inference path
        shape = x[0].shape  # BCHW
        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
        if self.dynamic or self.shape != shape:
+            assert(not self.export)
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
            self.shape = shape
@@ -74,6 +75,21 @@
        y = torch.cat((dbox, cls.sigmoid()), 1)
        return y if self.export else (y, x)

+    def forward_feat(self, x, cv2, cv3):
+        y = []
+        for i in range(self.nl):
+            y.append(torch.cat((cv2[i](x[i]), cv3[i](x[i])), 1))
+        return y
+
+    def forward(self, x):
+        """Concatenates and returns predicted bounding boxes and class probabilities."""
+        y = self.forward_feat(x, self.cv2, self.cv3)
+
+        if self.training:
+            return y
+
+        return self.inference(y)
+
    def bias_init(self):
        """Initialize Detect() biases, WARNING: requires stride availability."""
        m = self  # self.model[-1]  # Detect() module
@@ -480,3 +496,34 @@
        xavier_uniform_(self.query_pos_head.layers[1].weight)
        for layer in self.input_proj:
            xavier_uniform_(layer[0].weight)
+
+
+class v10Detect(Detect):
+    def __init__(self, nc=80, ch=()):
+        super().__init__(nc, ch)
+        c3 = max(ch[0], min(self.nc, 100))  # channels
+        self.cv3 = nn.ModuleList(
+            nn.Sequential(
+                nn.Sequential(Conv(x, x, 3, g=x), Conv(x, c3, 1)),
+                nn.Sequential(Conv(c3, c3, 3, g=c3), Conv(c3, c3, 1)),
+                nn.Conv2d(c3, self.nc, 1),
+            )
+            for i, x in enumerate(ch)
+        )
+        self.one2one_cv2 = copy.deepcopy(self.cv2)
+        self.one2one_cv3 = copy.deepcopy(self.cv3)
+
+    def forward(self, x):
+        one2one = self.forward_feat([xi.detach() for xi in x], self.one2one_cv2, self.one2one_cv3)
+        if not self.training:
+            one2one = self.inference(one2one)
+            return one2one
+        else:
+            one2many = super().forward(x)
+            return {"one2many": one2many, "one2one": one2one}
+
+    def bias_init(self):
+        super().bias_init()
+        """Initialize Detect() biases, WARNING: requires stride availability."""
+        m = self  # self.model[-1]  # Detect() module
+        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
+        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
+        for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride):  # from
+            a[-1].bias.data[:] = 1.0  # box
+            b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)

View File

@@ -49,10 +49,15 @@ from ultralytics.nn.modules import (
    CBFuse,
    CBLinear,
    Silence,
+    C2fCIB,
+    PSA,
+    SCDown,
+    RepVGGDW,
+    v10Detect
)
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
-from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss
+from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss, v10DetectLoss
from ultralytics.utils.plotting import feature_visualization
from ultralytics.utils.torch_utils import (
    fuse_conv_and_bn,
@@ -191,6 +196,9 @@ class BaseModel(nn.Module):
                if isinstance(m, RepConv):
                    m.fuse_convs()
                    m.forward = m.forward_fuse  # update forward
+                if isinstance(m, RepVGGDW):
+                    m.fuse()
+                    m.forward = m.forward_fuse
            self.info(verbose=verbose)

        return self
@@ -294,6 +302,8 @@ class DetectionModel(BaseModel):
            s = 256  # 2x min stride
            m.inplace = self.inplace
            forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
+            if isinstance(m, v10Detect):
+                forward = lambda x: self.forward(x)["one2many"]
            m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))])  # forward
            self.stride = m.stride
            m.bias_init()  # only run once
@@ -627,6 +637,9 @@ class WorldModel(DetectionModel):
                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
        return x


+class YOLOv10DetectionModel(DetectionModel):
+    def init_criterion(self):
+        return v10DetectLoss(self)
+
+
class Ensemble(nn.ModuleList):
    """Ensemble of models."""
@@ -869,6 +882,9 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
            DWConvTranspose2d,
            C3x,
            RepC3,
+            PSA,
+            SCDown,
+            C2fCIB
        }:
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
@@ -880,7 +896,7 @@
                )  # num heads

            args = [c1, c2, *args[1:]]
-            if m in (BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3):
+            if m in (BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3, C2fCIB):
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is AIFI:
@@ -897,7 +913,7 @@
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
-        elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn}:
+        elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}:
            args.append([ch[x] for x in f])
            if m is Segment:
                args[2] = make_divisible(min(args[2], max_channels) * width, 8)
@@ -936,7 +952,10 @@ def yaml_model_load(path):
        LOGGER.warning(f"WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.")
        path = path.with_name(new_stem + path.suffix)

-    unified_path = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", str(path))  # i.e. yolov8x.yaml -> yolov8.yaml
+    if "v10" not in str(path):
+        unified_path = re.sub(r"(\d+)([nsblmx])(.+)?$", r"\1\3", str(path))  # i.e. yolov8x.yaml -> yolov8.yaml
+    else:
+        unified_path = path
    yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
    d = yaml_load(yaml_file)  # model dict
    d["scale"] = guess_model_scale(path)
@@ -959,7 +978,7 @@ def guess_model_scale(model_path):
    with contextlib.suppress(AttributeError):
        import re

-        return re.search(r"yolov\d+([nslmx])", Path(model_path).stem).group(1)  # n, s, m, l, or x
+        return re.search(r"yolov\d+([nsblmx])", Path(model_path).stem).group(1)  # n, s, m, l, or x
    return ""
@@ -982,7 +1001,7 @@
        m = cfg["head"][-1][-2].lower()  # output module name
        if m in {"classify", "classifier", "cls", "fc"}:
            return "classify"
-        if m == "detect":
+        if m == "detect" or m == "v10detect":
            return "detect"
        if m == "segment":
            return "segment"
@@ -1014,7 +1033,7 @@
                return "pose"
            elif isinstance(m, OBB):
                return "obb"
-            elif isinstance(m, (Detect, WorldDetect)):
+            elif isinstance(m, (Detect, WorldDetect, v10Detect)):
                return "detect"

    # Guess from model filename

View File

@@ -147,7 +147,7 @@ class KeypointLoss(nn.Module):
class v8DetectionLoss:
    """Criterion class for computing training losses."""

-    def __init__(self, model):  # model must be de-paralleled
+    def __init__(self, model, tal_topk=10):  # model must be de-paralleled
        """Initializes v8DetectionLoss with the model, defining model-related properties and BCE loss function."""
        device = next(model.parameters()).device  # get model device
        h = model.args  # hyperparameters
@@ -163,7 +163,7 @@ class v8DetectionLoss:
        self.use_dfl = m.reg_max > 1

-        self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
+        self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
        self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device)
        self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
@@ -713,3 +713,15 @@ class v8OBBLoss(v8DetectionLoss):
            b, a, c = pred_dist.shape  # batch, anchors, channels
            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
        return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
+
+
+class v10DetectLoss:
+    def __init__(self, model):
+        self.one2many = v8DetectionLoss(model, tal_topk=10)
+        self.one2one = v8DetectionLoss(model, tal_topk=1)
+
+    def __call__(self, preds, batch):
+        one2many = preds["one2many"]
+        loss_one2many = self.one2many(one2many, batch)
+        one2one = preds["one2one"]
+        loss_one2one = self.one2one(one2one, batch)
+        return loss_one2many[0] + loss_one2one[0], torch.cat((loss_one2many[1], loss_one2one[1]))
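Editorial note: the scalar one-to-many and one-to-one losses are summed, while their per-component vectors are concatenated, which is why YOLOv10DetectionTrainer.get_validator above registers six loss names. A hypothetical shape sketch (assumes a built YOLOv10 model in training mode, whose head returns the dict from v10Detect.forward, and that v8DetectionLoss keeps its usual three components):

# criterion = v10DetectLoss(model)
# total, items = criterion({"one2many": p_many, "one2one": p_one}, batch)
# items has 6 entries: box_om, cls_om, dfl_om, box_oo, cls_oo, dfl_oo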

View File

@@ -308,7 +308,8 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Transform distance(ltrb) to box(xywh or xyxy)."""
-    lt, rb = distance.chunk(2, dim)
+    assert(distance.shape[dim] == 4)
+    lt, rb = distance.split([2, 2], dim)
    x1y1 = anchor_points - lt
    x2y2 = anchor_points + rb
    if xywh:

View File

@@ -310,10 +310,11 @@ def get_flops(model, imgsz=640):
        imgsz = [imgsz, imgsz]  # expand if int/float
    try:
        # Use stride size for input tensor
-        stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32  # max stride
-        im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
-        flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # stride GFLOPs
-        return flops * imgsz[0] / stride * imgsz[1] / stride  # imgsz GFLOPs
+        # stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32  # max stride
+        # im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
+        # flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # stride GFLOPs
+        # return flops * imgsz[0] / stride * imgsz[1] / stride  # imgsz GFLOPs
+        raise Exception
    except Exception:
        # Use actual image size for input tensor (i.e. required for RTDETR models)
        im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # input image in BCHW format