Mirror of https://github.com/THU-MIG/yolov10.git (synced 2025-05-23 13:34:23 +08:00)

Add RTDETR Trainer (#2745)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>

parent 03bce07848, commit a0ba8ef5f0
@@ -35,6 +35,7 @@ from ultralytics import RTDETR

model = RTDETR("rtdetr-l.pt")
model.info()  # display model information
model.train(data="coco8.yaml")  # train
model.predict("path/to/image.jpg")  # predict
```

@@ -51,7 +52,7 @@ model.predict("path/to/image.jpg")  # predict

|------------|--------------------|
| Inference  | :heavy_check_mark: |
| Validation | :heavy_check_mark: |
-| Training   | :x: (Coming soon)  |
+| Training   | :heavy_check_mark: |

# Citations and Acknowledgements

@@ -70,4 +71,4 @@ If you use Baidu's RT-DETR in your research or development work, please cite the

We would like to acknowledge Baidu and the [PaddlePaddle](https://github.com/PaddlePaddle/PaddleDetection) team for creating and maintaining this valuable resource for the computer vision community. Their contribution to the field with the development of the Vision Transformers-based real-time object detector, RT-DETR, is greatly appreciated.

*Keywords: RT-DETR, Transformer, ViT, Vision Transformers, Baidu RT-DETR, PaddlePaddle, Paddle Paddle RT-DETR, real-time object detection, Vision Transformers-based object detection, pre-trained PaddlePaddle RT-DETR models, Baidu's RT-DETR usage, Ultralytics Python API*
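With the Training row flipped above, the snippet from the usage section now covers the whole loop. An expanded sketch for illustration only (epoch/size values are arbitrary; `coco8.yaml` resolves as in the docs example):

```python
from ultralytics import RTDETR

model = RTDETR("rtdetr-l.pt")
model.train(data="coco8.yaml", epochs=3, imgsz=640)  # training is now supported
model.val(data="coco8.yaml")                         # validation was already supported
model.predict("path/to/image.jpg")
```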
@@ -8,6 +8,11 @@ keywords: Ultralytics, YOLO, loss functions, object detection, keypoint detectio

:::ultralytics.yolo.utils.loss.VarifocalLoss
<br><br>

# FocalLoss
---
:::ultralytics.yolo.utils.loss.FocalLoss
<br><br>

# BboxLoss
---
:::ultralytics.yolo.utils.loss.BboxLoss

@@ -3,11 +3,6 @@ description: Explore Ultralytics YOLO's FocalLoss, DetMetrics, PoseMetrics, Clas

keywords: YOLOv5, metrics, losses, confusion matrix, detection metrics, pose metrics, classification metrics, intersection over area, intersection over union, keypoint intersection over union, average precision, per class average precision, Ultralytics Docs
---

# FocalLoss
---
:::ultralytics.yolo.utils.metrics.FocalLoss
<br><br>

# ConfusionMatrix
---
:::ultralytics.yolo.utils.metrics.ConfusionMatrix
@@ -7,7 +7,7 @@ import numpy as np

import torch
from PIL import Image

-from ultralytics import YOLO
+from ultralytics import RTDETR, YOLO
from ultralytics.yolo.data.build import load_inference_source
from ultralytics.yolo.utils import LINUX, ONLINE, ROOT, SETTINGS

@@ -174,7 +174,10 @@ def test_export_paddle(enabled=False):

def test_all_model_yamls():
    for m in list((ROOT / 'models').rglob('yolo*.yaml')):
-        YOLO(m.name)
+        if m.name == 'yolov8-rtdetr.yaml':  # except the rtdetr model
+            RTDETR(m.name)
+        else:
+            YOLO(m.name)


def test_workflow():
ultralytics/models/v8/yolov8-rtdetr.yaml (new file, 46 lines)

@@ -0,0 +1,46 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)

  - [[15, 18, 21], 1, RTDETRDecoder, [nc]]  # Detect(P3, P4, P5)
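The YAML above couples the standard YOLOv8 backbone and PAN-style neck with an `RTDETRDecoder` head over the P3/P4/P5 maps (layers 15, 18, 21). A minimal sketch of building a model from it with the `RTDETR` wrapper, which the updated test above also exercises for this config (assumes the package from this commit is importable):

```python
from ultralytics import RTDETR

model = RTDETR('yolov8-rtdetr.yaml')  # builds an RTDETRDetectionModel from the config above
model.info()                          # layer / parameter summary, as in the docs snippet
```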
@ -163,111 +163,178 @@ class RTDETRDecoder(nn.Module):
|
||||
self,
|
||||
nc=80,
|
||||
ch=(512, 1024, 2048),
|
||||
hidden_dim=256,
|
||||
num_queries=300,
|
||||
strides=(8, 16, 32), # TODO
|
||||
nl=3,
|
||||
num_decoder_points=4,
|
||||
nhead=8,
|
||||
num_decoder_layers=6,
|
||||
dim_feedforward=1024,
|
||||
hd=256, # hidden dim
|
||||
nq=300, # num queries
|
||||
ndp=4, # num decoder points
|
||||
nh=8, # num head
|
||||
ndl=6, # num decoder layers
|
||||
d_ffn=1024, # dim of feedforward
|
||||
dropout=0.,
|
||||
act=nn.ReLU(),
|
||||
eval_idx=-1,
|
||||
# training args
|
||||
num_denoising=100,
|
||||
nd=100, # num denoising
|
||||
label_noise_ratio=0.5,
|
||||
box_noise_scale=1.0,
|
||||
learnt_init_query=False):
|
||||
super().__init__()
|
||||
assert len(ch) <= nl
|
||||
assert len(strides) == len(ch)
|
||||
for _ in range(nl - len(strides)):
|
||||
strides.append(strides[-1] * 2)
|
||||
|
||||
self.hidden_dim = hidden_dim
|
||||
self.nhead = nhead
|
||||
self.feat_strides = strides
|
||||
self.nl = nl
|
||||
self.hidden_dim = hd
|
||||
self.nhead = nh
|
||||
self.nl = len(ch) # num level
|
||||
self.nc = nc
|
||||
self.num_queries = num_queries
|
||||
self.num_decoder_layers = num_decoder_layers
|
||||
self.num_queries = nq
|
||||
self.num_decoder_layers = ndl
|
||||
|
||||
# backbone feature projection
|
||||
self._build_input_proj_layer(ch)
|
||||
self.input_proj = nn.ModuleList(nn.Sequential(nn.Conv2d(x, hd, 1, bias=False), nn.BatchNorm2d(hd)) for x in ch)
|
||||
# NOTE: simplified version but it's not consistent with .pt weights.
|
||||
# self.input_proj = nn.ModuleList(Conv(x, hd, act=False) for x in ch)
|
||||
|
||||
# Transformer module
|
||||
decoder_layer = DeformableTransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, dropout, act, nl,
|
||||
num_decoder_points)
|
||||
self.decoder = DeformableTransformerDecoder(hidden_dim, decoder_layer, num_decoder_layers, eval_idx)
|
||||
decoder_layer = DeformableTransformerDecoderLayer(hd, nh, d_ffn, dropout, act, self.nl, ndp)
|
||||
self.decoder = DeformableTransformerDecoder(hd, decoder_layer, ndl, eval_idx)
|
||||
|
||||
# denoising part
|
||||
self.denoising_class_embed = nn.Embedding(nc, hidden_dim)
|
||||
self.num_denoising = num_denoising
|
||||
self.denoising_class_embed = nn.Embedding(nc, hd)
|
||||
self.num_denoising = nd
|
||||
self.label_noise_ratio = label_noise_ratio
|
||||
self.box_noise_scale = box_noise_scale
|
||||
|
||||
# decoder embedding
|
||||
self.learnt_init_query = learnt_init_query
|
||||
if learnt_init_query:
|
||||
self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
|
||||
self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2)
|
||||
self.tgt_embed = nn.Embedding(nq, hd)
|
||||
self.query_pos_head = MLP(4, 2 * hd, hd, num_layers=2)
|
||||
|
||||
# encoder head
|
||||
self.enc_output = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim))
|
||||
self.enc_score_head = nn.Linear(hidden_dim, nc)
|
||||
self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3)
|
||||
self.enc_output = nn.Sequential(nn.Linear(hd, hd), nn.LayerNorm(hd))
|
||||
self.enc_score_head = nn.Linear(hd, nc)
|
||||
self.enc_bbox_head = MLP(hd, hd, 4, num_layers=3)
|
||||
|
||||
# decoder head
|
||||
self.dec_score_head = nn.ModuleList([nn.Linear(hidden_dim, nc) for _ in range(num_decoder_layers)])
|
||||
self.dec_bbox_head = nn.ModuleList([
|
||||
MLP(hidden_dim, hidden_dim, 4, num_layers=3) for _ in range(num_decoder_layers)])
|
||||
self.dec_score_head = nn.ModuleList([nn.Linear(hd, nc) for _ in range(ndl)])
|
||||
self.dec_bbox_head = nn.ModuleList([MLP(hd, hd, 4, num_layers=3) for _ in range(ndl)])
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
def forward(self, feats, gt_meta=None):
|
||||
def forward(self, x, batch=None):
|
||||
from ultralytics.vit.utils.ops import get_cdn_group
|
||||
|
||||
# input projection and embedding
|
||||
memory, spatial_shapes, _ = self._get_encoder_input(feats)
|
||||
feats, shapes = self._get_encoder_input(x)
|
||||
|
||||
# prepare denoising training
|
||||
if self.training:
|
||||
raise NotImplementedError
|
||||
# denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
|
||||
# get_contrastive_denoising_training_group(gt_meta,
|
||||
# self.num_classes,
|
||||
# self.num_queries,
|
||||
# self.denoising_class_embed.weight,
|
||||
# self.num_denoising,
|
||||
# self.label_noise_ratio,
|
||||
# self.box_noise_scale)
|
||||
else:
|
||||
denoising_class, denoising_bbox_unact, attn_mask = None, None, None
|
||||
dn_embed, dn_bbox, attn_mask, dn_meta = \
|
||||
get_cdn_group(batch,
|
||||
self.nc,
|
||||
self.num_queries,
|
||||
self.denoising_class_embed.weight,
|
||||
self.num_denoising,
|
||||
self.label_noise_ratio,
|
||||
self.box_noise_scale,
|
||||
self.training)
|
||||
|
||||
target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \
|
||||
self._get_decoder_input(memory, spatial_shapes, denoising_class, denoising_bbox_unact)
|
||||
embed, refer_bbox, enc_bboxes, enc_scores = \
|
||||
self._get_decoder_input(feats, shapes, dn_embed, dn_bbox)
|
||||
|
||||
# decoder
|
||||
out_bboxes, out_logits = self.decoder(target,
|
||||
init_ref_points_unact,
|
||||
memory,
|
||||
spatial_shapes,
|
||||
dec_bboxes, dec_scores = self.decoder(embed,
|
||||
refer_bbox,
|
||||
feats,
|
||||
shapes,
|
||||
self.dec_bbox_head,
|
||||
self.dec_score_head,
|
||||
self.query_pos_head,
|
||||
attn_mask=attn_mask)
|
||||
if not self.training:
|
||||
out_logits = out_logits.sigmoid_()
|
||||
return out_bboxes, out_logits # enc_topk_bboxes, enc_topk_logits, dn_meta
|
||||
dec_scores = dec_scores.sigmoid_()
|
||||
return dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta
|
||||
|
||||
def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2):
|
||||
anchors = []
|
||||
for i, (h, w) in enumerate(shapes):
|
||||
grid_y, grid_x = torch.meshgrid(torch.arange(end=h, dtype=dtype, device=device),
|
||||
torch.arange(end=w, dtype=dtype, device=device),
|
||||
indexing='ij')
|
||||
grid_xy = torch.stack([grid_x, grid_y], -1) # (h, w, 2)
|
||||
|
||||
valid_WH = torch.tensor([h, w], dtype=dtype, device=device)
|
||||
grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH # (1, h, w, 2)
|
||||
wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0 ** i)
|
||||
anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4)) # (1, h*w, 4)
|
||||
|
||||
anchors = torch.cat(anchors, 1) # (1, h*w*nl, 4)
|
||||
valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True) # 1, h*w*nl, 1
|
||||
anchors = torch.log(anchors / (1 - anchors))
|
||||
anchors = torch.where(valid_mask, anchors, torch.inf)
|
||||
return anchors, valid_mask
|
||||
|
||||
def _get_encoder_input(self, x):
|
||||
# get projection features
|
||||
x = [self.input_proj[i](feat) for i, feat in enumerate(x)]
|
||||
# get encoder inputs
|
||||
feats = []
|
||||
shapes = []
|
||||
for feat in x:
|
||||
h, w = feat.shape[2:]
|
||||
# [b, c, h, w] -> [b, h*w, c]
|
||||
feats.append(feat.flatten(2).permute(0, 2, 1))
|
||||
# [nl, 2]
|
||||
shapes.append([h, w])
|
||||
|
||||
# [b, h*w, c]
|
||||
feats = torch.cat(feats, 1)
|
||||
return feats, shapes
|
||||
|
||||
def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None):
|
||||
bs = len(feats)
|
||||
# prepare input for decoder
|
||||
anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)
|
||||
features = self.enc_output(torch.where(valid_mask, feats, 0)) # bs, h*w, 256
|
||||
|
||||
enc_outputs_scores = self.enc_score_head(features) # (bs, h*w, nc)
|
||||
# dynamic anchors + static content
|
||||
enc_outputs_bboxes = self.enc_bbox_head(features) + anchors # (bs, h*w, 4)
|
||||
|
||||
# query selection
|
||||
# (bs, num_queries)
|
||||
topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1)
|
||||
# (bs, num_queries)
|
||||
batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1)
|
||||
|
||||
# Unsigmoided
|
||||
refer_bbox = enc_outputs_bboxes[batch_ind, topk_ind].view(bs, self.num_queries, -1)
|
||||
# refer_bbox = torch.gather(enc_outputs_bboxes, 1, topk_ind.reshape(bs, self.num_queries).unsqueeze(-1).repeat(1, 1, 4))
|
||||
|
||||
enc_bboxes = refer_bbox.sigmoid()
|
||||
if dn_bbox is not None:
|
||||
refer_bbox = torch.cat([dn_bbox, refer_bbox], 1)
|
||||
if self.training:
|
||||
refer_bbox = refer_bbox.detach()
|
||||
enc_scores = enc_outputs_scores[batch_ind, topk_ind].view(bs, self.num_queries, -1)
|
||||
|
||||
if self.learnt_init_query:
|
||||
embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
|
||||
else:
|
||||
embeddings = features[batch_ind, topk_ind].view(bs, self.num_queries, -1)
|
||||
if self.training:
|
||||
embeddings = embeddings.detach()
|
||||
if dn_embed is not None:
|
||||
embeddings = torch.cat([dn_embed, embeddings], 1)
|
||||
|
||||
return embeddings, refer_bbox, enc_bboxes, enc_scores
|
||||
|
||||
# TODO
|
||||
def _reset_parameters(self):
|
||||
# class and bbox head init
|
||||
bias_cls = bias_init_with_prob(0.01)
|
||||
linear_init_(self.enc_score_head)
|
||||
bias_cls = bias_init_with_prob(0.01) / 80 * self.nc
|
||||
# NOTE: the weight initialization in `linear_init_` would cause NaN when training with custom datasets.
|
||||
# linear_init_(self.enc_score_head)
|
||||
constant_(self.enc_score_head.bias, bias_cls)
|
||||
constant_(self.enc_bbox_head.layers[-1].weight, 0.)
|
||||
constant_(self.enc_bbox_head.layers[-1].bias, 0.)
|
||||
for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head):
|
||||
linear_init_(cls_)
|
||||
# linear_init_(cls_)
|
||||
constant_(cls_.bias, bias_cls)
|
||||
constant_(reg_.layers[-1].weight, 0.)
|
||||
constant_(reg_.layers[-1].bias, 0.)
|
||||
@ -280,103 +347,3 @@ class RTDETRDecoder(nn.Module):
|
||||
xavier_uniform_(self.query_pos_head.layers[1].weight)
|
||||
for layer in self.input_proj:
|
||||
xavier_uniform_(layer[0].weight)
|
||||
|
||||
def _build_input_proj_layer(self, ch):
|
||||
self.input_proj = nn.ModuleList()
|
||||
for in_channels in ch:
|
||||
self.input_proj.append(
|
||||
nn.Sequential(nn.Conv2d(in_channels, self.hidden_dim, kernel_size=1, bias=False),
|
||||
nn.BatchNorm2d(self.hidden_dim)))
|
||||
in_channels = ch[-1]
|
||||
for _ in range(self.nl - len(ch)):
|
||||
self.input_proj.append(
|
||||
nn.Sequential(nn.Conv2D(in_channels, self.hidden_dim, kernel_size=3, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(self.hidden_dim)))
|
||||
in_channels = self.hidden_dim
|
||||
|
||||
def _generate_anchors(self, spatial_shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2):
|
||||
anchors = []
|
||||
for lvl, (h, w) in enumerate(spatial_shapes):
|
||||
grid_y, grid_x = torch.meshgrid(torch.arange(end=h, dtype=torch.float32),
|
||||
torch.arange(end=w, dtype=torch.float32),
|
||||
indexing='ij')
|
||||
grid_xy = torch.stack([grid_x, grid_y], -1)
|
||||
|
||||
valid_WH = torch.tensor([h, w]).to(torch.float32)
|
||||
grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH
|
||||
wh = torch.ones_like(grid_xy) * grid_size * (2.0 ** lvl)
|
||||
anchors.append(torch.concat([grid_xy, wh], -1).reshape([-1, h * w, 4]))
|
||||
|
||||
anchors = torch.concat(anchors, 1)
|
||||
valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)
|
||||
anchors = torch.log(anchors / (1 - anchors))
|
||||
anchors = torch.where(valid_mask, anchors, torch.inf)
|
||||
return anchors.to(device=device, dtype=dtype), valid_mask.to(device=device)
|
||||
|
||||
def _get_encoder_input(self, feats):
|
||||
# get projection features
|
||||
proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
|
||||
if self.nl > len(proj_feats):
|
||||
len_srcs = len(proj_feats)
|
||||
for i in range(len_srcs, self.nl):
|
||||
if i == len_srcs:
|
||||
proj_feats.append(self.input_proj[i](feats[-1]))
|
||||
else:
|
||||
proj_feats.append(self.input_proj[i](proj_feats[-1]))
|
||||
|
||||
# get encoder inputs
|
||||
feat_flatten = []
|
||||
spatial_shapes = []
|
||||
level_start_index = [0]
|
||||
for feat in proj_feats:
|
||||
_, _, h, w = feat.shape
|
||||
# [b, c, h, w] -> [b, h*w, c]
|
||||
feat_flatten.append(feat.flatten(2).permute(0, 2, 1))
|
||||
# [nl, 2]
|
||||
spatial_shapes.append([h, w])
|
||||
# [l], start index of each level
|
||||
level_start_index.append(h * w + level_start_index[-1])
|
||||
|
||||
# [b, l, c]
|
||||
feat_flatten = torch.concat(feat_flatten, 1)
|
||||
level_start_index.pop()
|
||||
return feat_flatten, spatial_shapes, level_start_index
|
||||
|
||||
def _get_decoder_input(self, memory, spatial_shapes, denoising_class=None, denoising_bbox_unact=None):
|
||||
bs, _, _ = memory.shape
|
||||
# prepare input for decoder
|
||||
anchors, valid_mask = self._generate_anchors(spatial_shapes, dtype=memory.dtype, device=memory.device)
|
||||
memory = torch.where(valid_mask, memory, 0)
|
||||
output_memory = self.enc_output(memory)
|
||||
|
||||
enc_outputs_class = self.enc_score_head(output_memory) # (bs, h*w, nc)
|
||||
enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors # (bs, h*w, 4)
|
||||
|
||||
# (bs, topk)
|
||||
_, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.num_queries, dim=1)
|
||||
# extract region proposal boxes
|
||||
# (bs, topk_ind)
|
||||
batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1)
|
||||
topk_ind = topk_ind.view(-1)
|
||||
|
||||
# Unsigmoided
|
||||
reference_points_unact = enc_outputs_coord_unact[batch_ind, topk_ind].view(bs, self.num_queries, -1)
|
||||
|
||||
enc_topk_bboxes = torch.sigmoid(reference_points_unact)
|
||||
if denoising_bbox_unact is not None:
|
||||
reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1)
|
||||
if self.training:
|
||||
reference_points_unact = reference_points_unact.detach()
|
||||
enc_topk_logits = enc_outputs_class[batch_ind, topk_ind].view(bs, self.num_queries, -1)
|
||||
|
||||
# extract region features
|
||||
if self.learnt_init_query:
|
||||
target = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
|
||||
else:
|
||||
target = output_memory[batch_ind, topk_ind].view(bs, self.num_queries, -1)
|
||||
if self.training:
|
||||
target = target.detach()
|
||||
if denoising_class is not None:
|
||||
target = torch.concat([denoising_class, target], 1)
|
||||
|
||||
return target, reference_points_unact, enc_topk_bboxes, enc_topk_logits
|
||||
|
@ -229,23 +229,23 @@ class MSDeformAttn(nn.Module):
|
||||
xavier_uniform_(self.output_proj.weight.data)
|
||||
constant_(self.output_proj.bias.data, 0.)
|
||||
|
||||
def forward(self, query, reference_points, value, value_spatial_shapes, value_mask=None):
|
||||
def forward(self, query, refer_bbox, value, value_shapes, value_mask=None):
|
||||
"""
|
||||
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py
|
||||
Args:
|
||||
query (Tensor): [bs, query_length, C]
|
||||
reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0),
|
||||
query (torch.Tensor): [bs, query_length, C]
|
||||
refer_bbox (torch.Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0),
|
||||
bottom-right (1, 1), including padding area
|
||||
value (Tensor): [bs, value_length, C]
|
||||
value_spatial_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
|
||||
value (torch.Tensor): [bs, value_length, C]
|
||||
value_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
|
||||
value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements
|
||||
|
||||
Returns:
|
||||
output (Tensor): [bs, Length_{query}, C]
|
||||
"""
|
||||
bs, len_q = query.shape[:2]
|
||||
_, len_v = value.shape[:2]
|
||||
assert sum(s[0] * s[1] for s in value_spatial_shapes) == len_v
|
||||
len_v = value.shape[1]
|
||||
assert sum(s[0] * s[1] for s in value_shapes) == len_v
|
||||
|
||||
value = self.value_proj(value)
|
||||
if value_mask is not None:
|
||||
@ -255,18 +255,17 @@ class MSDeformAttn(nn.Module):
|
||||
attention_weights = self.attention_weights(query).view(bs, len_q, self.n_heads, self.n_levels * self.n_points)
|
||||
attention_weights = F.softmax(attention_weights, -1).view(bs, len_q, self.n_heads, self.n_levels, self.n_points)
|
||||
# N, Len_q, n_heads, n_levels, n_points, 2
|
||||
n = reference_points.shape[-1]
|
||||
if n == 2:
|
||||
offset_normalizer = torch.as_tensor(value_spatial_shapes, dtype=query.dtype, device=query.device).flip(-1)
|
||||
num_points = refer_bbox.shape[-1]
|
||||
if num_points == 2:
|
||||
offset_normalizer = torch.as_tensor(value_shapes, dtype=query.dtype, device=query.device).flip(-1)
|
||||
add = sampling_offsets / offset_normalizer[None, None, None, :, None, :]
|
||||
sampling_locations = reference_points[:, :, None, :, None, :] + add
|
||||
|
||||
elif n == 4:
|
||||
add = sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
|
||||
sampling_locations = reference_points[:, :, None, :, None, :2] + add
|
||||
sampling_locations = refer_bbox[:, :, None, :, None, :] + add
|
||||
elif num_points == 4:
|
||||
add = sampling_offsets / self.n_points * refer_bbox[:, :, None, :, None, 2:] * 0.5
|
||||
sampling_locations = refer_bbox[:, :, None, :, None, :2] + add
|
||||
else:
|
||||
raise ValueError(f'Last dim of reference_points must be 2 or 4, but got {n}.')
|
||||
output = multi_scale_deformable_attn_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights)
|
||||
raise ValueError(f'Last dim of reference_points must be 2 or 4, but got {num_points}.')
|
||||
output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights)
|
||||
output = self.output_proj(output)
|
||||
return output
|
||||
|
||||
@ -308,33 +307,24 @@ class DeformableTransformerDecoderLayer(nn.Module):
|
||||
tgt = self.norm3(tgt)
|
||||
return tgt
|
||||
|
||||
def forward(self,
|
||||
tgt,
|
||||
reference_points,
|
||||
src,
|
||||
src_spatial_shapes,
|
||||
src_padding_mask=None,
|
||||
attn_mask=None,
|
||||
query_pos=None):
|
||||
def forward(self, embed, refer_bbox, feats, shapes, padding_mask=None, attn_mask=None, query_pos=None):
|
||||
# self attention
|
||||
q = k = self.with_pos_embed(tgt, query_pos)
|
||||
if attn_mask is not None:
|
||||
attn_mask = torch.where(attn_mask.astype('bool'), torch.zeros(attn_mask.shape, tgt.dtype),
|
||||
torch.full(attn_mask.shape, float('-inf'), tgt.dtype))
|
||||
tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1)
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt = self.norm1(tgt)
|
||||
q = k = self.with_pos_embed(embed, query_pos)
|
||||
tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1),
|
||||
attn_mask=attn_mask)[0].transpose(0, 1)
|
||||
embed = embed + self.dropout1(tgt)
|
||||
embed = self.norm1(embed)
|
||||
|
||||
# cross attention
|
||||
tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos), reference_points, src, src_spatial_shapes,
|
||||
src_padding_mask)
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt = self.norm2(tgt)
|
||||
tgt = self.cross_attn(self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes,
|
||||
padding_mask)
|
||||
embed = embed + self.dropout2(tgt)
|
||||
embed = self.norm2(embed)
|
||||
|
||||
# ffn
|
||||
tgt = self.forward_ffn(tgt)
|
||||
embed = self.forward_ffn(embed)
|
||||
|
||||
return tgt
|
||||
return embed
|
||||
|
||||
|
||||
class DeformableTransformerDecoder(nn.Module):
|
||||
@ -349,41 +339,40 @@ class DeformableTransformerDecoder(nn.Module):
|
||||
self.hidden_dim = hidden_dim
|
||||
self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
|
||||
|
||||
def forward(self,
|
||||
tgt,
|
||||
reference_points,
|
||||
src,
|
||||
src_spatial_shapes,
|
||||
bbox_head,
|
||||
score_head,
|
||||
query_pos_head,
|
||||
attn_mask=None,
|
||||
src_padding_mask=None):
|
||||
output = tgt
|
||||
dec_out_bboxes = []
|
||||
dec_out_logits = []
|
||||
ref_points = None
|
||||
ref_points_detach = torch.sigmoid(reference_points)
|
||||
def forward(
|
||||
self,
|
||||
embed, # decoder embeddings
|
||||
refer_bbox, # anchor
|
||||
feats, # image features
|
||||
shapes, # feature shapes
|
||||
bbox_head,
|
||||
score_head,
|
||||
pos_mlp,
|
||||
attn_mask=None,
|
||||
padding_mask=None):
|
||||
output = embed
|
||||
dec_bboxes = []
|
||||
dec_cls = []
|
||||
last_refined_bbox = None
|
||||
refer_bbox = refer_bbox.sigmoid()
|
||||
for i, layer in enumerate(self.layers):
|
||||
ref_points_input = ref_points_detach.unsqueeze(2)
|
||||
query_pos_embed = query_pos_head(ref_points_detach)
|
||||
output = layer(output, ref_points_input, src, src_spatial_shapes, src_padding_mask, attn_mask,
|
||||
query_pos_embed)
|
||||
output = layer(output, refer_bbox, feats, shapes, padding_mask, attn_mask, pos_mlp(refer_bbox))
|
||||
|
||||
inter_ref_bbox = torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points_detach))
|
||||
# refine bboxes, (bs, num_queries+num_denoising, 4)
|
||||
refined_bbox = torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(refer_bbox))
|
||||
|
||||
if self.training:
|
||||
dec_out_logits.append(score_head[i](output))
|
||||
dec_cls.append(score_head[i](output))
|
||||
if i == 0:
|
||||
dec_out_bboxes.append(inter_ref_bbox)
|
||||
dec_bboxes.append(refined_bbox)
|
||||
else:
|
||||
dec_out_bboxes.append(torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points)))
|
||||
dec_bboxes.append(torch.sigmoid(bbox_head[i](output) + inverse_sigmoid(last_refined_bbox)))
|
||||
elif i == self.eval_idx:
|
||||
dec_out_logits.append(score_head[i](output))
|
||||
dec_out_bboxes.append(inter_ref_bbox)
|
||||
dec_cls.append(score_head[i](output))
|
||||
dec_bboxes.append(refined_bbox)
|
||||
break
|
||||
|
||||
ref_points = inter_ref_bbox
|
||||
ref_points_detach = inter_ref_bbox.detach() if self.training else inter_ref_bbox
|
||||
last_refined_bbox = refined_bbox
|
||||
refer_bbox = refined_bbox.detach() if self.training else refined_bbox
|
||||
|
||||
return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits)
|
||||
return torch.stack(dec_bboxes), torch.stack(dec_cls)
|
||||
|
@@ -210,7 +210,9 @@ class BaseModel(nn.Module):
        """
        if not hasattr(self, 'criterion'):
            self.criterion = self.init_criterion()
-        return self.criterion(self.predict(batch['img']) if preds is None else preds, batch)
+
+        preds = self.forward(batch['img']) if preds is None else preds
+        return self.criterion(preds, batch)

    def init_criterion(self):
        raise NotImplementedError('compute_loss() needs to be implemented by task heads')

@@ -410,7 +412,7 @@ class RTDETRDetectionModel(DetectionModel):
        """Compute the classification loss between predictions and true labels."""
        from ultralytics.vit.utils.loss import RTDETRDetectionLoss

-        return RTDETRDetectionLoss(num_classes=self.nc, use_vfl=True)
+        return RTDETRDetectionLoss(nc=self.nc, use_vfl=True)

    def loss(self, batch, preds=None):
        if not hasattr(self, 'criterion'):
@ -420,31 +422,36 @@ class RTDETRDetectionModel(DetectionModel):
|
||||
# NOTE: preprocess gt_bbox and gt_labels to list.
|
||||
bs = len(img)
|
||||
batch_idx = batch['batch_idx']
|
||||
gt_bbox, gt_class = [], []
|
||||
gt_groups = []
|
||||
for i in range(bs):
|
||||
gt_bbox.append(batch['bboxes'][batch_idx == i].to(img.device))
|
||||
gt_class.append(batch['cls'][batch_idx == i].to(device=img.device, dtype=torch.long))
|
||||
targets = {'cls': gt_class, 'bboxes': gt_bbox}
|
||||
gt_groups.append((batch_idx == i).sum().item())
|
||||
targets = {
|
||||
'cls': batch['cls'].to(img.device, dtype=torch.long).view(-1),
|
||||
'bboxes': batch['bboxes'].to(device=img.device),
|
||||
'batch_idx': batch_idx.to(img.device, dtype=torch.long).view(-1),
|
||||
'gt_groups': gt_groups}
|
||||
|
||||
preds = self.predict(img, batch=targets) if preds is None else preds
|
||||
dec_out_bboxes, dec_out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta = preds
|
||||
# NOTE: `dn_meta` means it's eval mode, loss calculation for eval mode is not supported.
|
||||
dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds
|
||||
if dn_meta is None:
|
||||
return 0, torch.zeros(3, device=dec_out_bboxes.device)
|
||||
dn_out_bboxes, dec_out_bboxes = torch.split(dec_out_bboxes, dn_meta['dn_num_split'], dim=2)
|
||||
dn_out_logits, dec_out_logits = torch.split(dec_out_logits, dn_meta['dn_num_split'], dim=2)
|
||||
dn_bboxes, dn_scores = None, None
|
||||
else:
|
||||
dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta['dn_num_split'], dim=2)
|
||||
dn_scores, dec_scores = torch.split(dec_scores, dn_meta['dn_num_split'], dim=2)
|
||||
|
||||
out_bboxes = torch.cat([enc_topk_bboxes.unsqueeze(0), dec_out_bboxes])
|
||||
out_logits = torch.cat([enc_topk_logits.unsqueeze(0), dec_out_logits])
|
||||
dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes]) # (7, bs, 300, 4)
|
||||
dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores])
|
||||
|
||||
loss = self.criterion((out_bboxes, out_logits),
|
||||
loss = self.criterion((dec_bboxes, dec_scores),
|
||||
targets,
|
||||
dn_out_bboxes=dn_out_bboxes,
|
||||
dn_out_logits=dn_out_logits,
|
||||
dn_bboxes=dn_bboxes,
|
||||
dn_scores=dn_scores,
|
||||
dn_meta=dn_meta)
|
||||
return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']])
|
||||
# NOTE: There are like 12 losses in RTDETR, backward with all losses but only show the main three losses.
|
||||
return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']],
|
||||
device=img.device)
|
||||
|
||||
def predict(self, x, profile=False, visualize=False, batch=None):
|
||||
def predict(self, x, profile=False, visualize=False, batch=None, augment=False):
|
||||
"""
|
||||
Perform a forward pass through the network.
|
||||
|
||||
|
@@ -3,4 +3,4 @@
from .rtdetr import RTDETR
from .sam import SAM

-__all__ = 'RTDETR', 'SAM', 'SAM'  # allow simpler import
+__all__ = 'RTDETR', 'SAM'  # allow simpler import
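For reference, the re-exports above are what enable the "simpler import"; `__all__` additionally governs the wildcard form, so after the duplicate entry is dropped a star-import still exposes both names. A quick sketch (weight filename is illustrative only):

```python
from ultralytics.vit import *   # star-import follows __all__, bringing RTDETR and SAM into scope

model = RTDETR('rtdetr-l.pt')   # re-exported from the .rtdetr submodule
```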
@@ -5,15 +5,15 @@

from pathlib import Path

-from ultralytics.nn.tasks import DetectionModel, attempt_load_one_weight, yaml_model_load
+from ultralytics.nn.tasks import RTDETRDetectionModel, attempt_load_one_weight, yaml_model_load
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter
-from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, ROOT, is_git_dir
+from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, RANK, ROOT, is_git_dir
from ultralytics.yolo.utils.checks import check_imgsz
-from ultralytics.yolo.utils.torch_utils import model_info
+from ultralytics.yolo.utils.torch_utils import model_info, smart_inference_mode

-from ...yolo.utils.torch_utils import smart_inference_mode
from .predict import RTDETRPredictor
+from .train import RTDETRTrainer
from .val import RTDETRValidator


@@ -24,6 +24,7 @@ class RTDETR:
            raise NotImplementedError('RT-DETR only supports creating from pt file or yaml file.')
        # Load or create new YOLO model
        self.predictor = None
+        self.ckpt = None
        suffix = Path(model).suffix
        if suffix == '.yaml':
            self._new(model)

@@ -34,7 +35,7 @@ class RTDETR:
        cfg_dict = yaml_model_load(cfg)
        self.cfg = cfg
        self.task = 'detect'
-        self.model = DetectionModel(cfg_dict, verbose=verbose)  # build model
+        self.model = RTDETRDetectionModel(cfg_dict, verbose=verbose)  # build model

        # Below added to allow export from yamls
        self.model.args = DEFAULT_CFG_DICT  # attach args to model

@@ -42,10 +43,20 @@

    @smart_inference_mode()
    def _load(self, weights: str):
-        self.model, _ = attempt_load_one_weight(weights)
+        self.model, self.ckpt = attempt_load_one_weight(weights)
        self.model.args = DEFAULT_CFG_DICT  # attach args to model
        self.task = self.model.args['task']

+    @smart_inference_mode()
+    def load(self, weights='yolov8n.pt'):
+        """
+        Transfers parameters with matching names and shapes from 'weights' to model.
+        """
+        if isinstance(weights, (str, Path)):
+            weights, self.ckpt = attempt_load_one_weight(weights)
+        self.model.load(weights)
+        return self
+
    @smart_inference_mode()
    def predict(self, source=None, stream=False, **kwargs):
        """
@@ -74,8 +85,30 @@
        return self.predictor(source, stream=stream)

    def train(self, **kwargs):
-        """Function trains models but raises an error as RTDETR models do not support training."""
-        raise NotImplementedError("RTDETR models don't support training")
+        """
+        Trains the model on a given dataset.
+
+        Args:
+            **kwargs (Any): Any number of arguments representing the training configuration.
+        """
+        overrides = dict(task='detect', mode='train')
+        overrides.update(kwargs)
+        overrides['deterministic'] = False
+        if not overrides.get('data'):
+            raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
+        if overrides.get('resume'):
+            overrides['resume'] = self.ckpt_path
+        self.task = overrides.get('task') or self.task
+        self.trainer = RTDETRTrainer(overrides=overrides)
+        if not overrides.get('resume'):  # manually set model only if not resuming
+            self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
+            self.model = self.trainer.model
+        self.trainer.train()
+        # Update model and cfg after training
+        if RANK in (-1, 0):
+            self.model, _ = attempt_load_one_weight(str(self.trainer.best))
+            self.overrides = self.model.args
+            self.metrics = getattr(self.trainer.validator, 'metrics', None)  # TODO: no metrics returned by DDP

    def val(self, **kwargs):
        """Run validation given dataset."""
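A short usage sketch of the new `train()` method (argument values are illustrative; keyword arguments become trainer overrides, and the method itself forces `deterministic=False` since RT-DETR's `F.grid_sample` path has no deterministic implementation, per the note in the trainer module below):

```python
from ultralytics import RTDETR

model = RTDETR('rtdetr-l.pt')
model.train(data='coco8.yaml', epochs=10, batch=4, imgsz=640)  # kwargs -> RTDETRTrainer overrides
print(model.metrics)  # filled from the trainer's validator on rank -1/0
```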
ultralytics/vit/rtdetr/train.py (new file, 78 lines)

@@ -0,0 +1,78 @@
from copy import copy

import torch

from ultralytics.nn.tasks import RTDETRDetectionModel
from ultralytics.yolo.utils import DEFAULT_CFG, RANK, colorstr
from ultralytics.yolo.v8.detect import DetectionTrainer

from .val import RTDETRDataset, RTDETRValidator


class RTDETRTrainer(DetectionTrainer):

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Return a YOLO detection model."""
        model = RTDETRDetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model

    def build_dataset(self, img_path, mode='val', batch=None):
        """Build RTDETR Dataset

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
        """
        return RTDETRDataset(
            img_path=img_path,
            imgsz=self.args.imgsz,
            batch_size=batch,
            augment=mode == 'train',  # no augmentation
            hyp=self.args,
            rect=False,  # no rect
            cache=self.args.cache or None,
            prefix=colorstr(f'{mode}: '),
            data=self.data)

    def get_validator(self):
        """Returns a DetectionValidator for RTDETR model validation."""
        self.loss_names = 'giou_loss', 'cls_loss', 'l1_loss'
        return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

    def preprocess_batch(self, batch):
        """Preprocesses a batch of images by scaling and converting to float."""
        batch = super().preprocess_batch(batch)
        bs = len(batch['img'])
        batch_idx = batch['batch_idx']
        gt_bbox, gt_class = [], []
        for i in range(bs):
            gt_bbox.append(batch['bboxes'][batch_idx == i].to(batch_idx.device))
            gt_class.append(batch['cls'][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
        return batch


def train(cfg=DEFAULT_CFG, use_python=False):
    """Train and optimize RTDETR model given training data and device."""
    model = 'rtdetr-l.yaml'
    data = cfg.data or 'coco128.yaml'  # or yolo.ClassificationDataset("mnist")
    device = cfg.device if cfg.device is not None else ''

    # NOTE: F.grid_sample which is in rt-detr does not support deterministic=True
    # NOTE: amp training causes nan outputs and end with error while doing bipartite graph matching
    args = dict(model=model,
                data=data,
                device=device,
                imgsz=640,
                exist_ok=True,
                batch=4,
                deterministic=False,
                amp=False)
    trainer = RTDETRTrainer(overrides=args)
    trainer.train()


if __name__ == '__main__':
    train()
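The `loss_names` set in `get_validator` above line up with the three detached components that `RTDETRDetectionModel.loss` reports (GIoU, class, and L1 box loss); the remaining auxiliary and denoising DETR terms are summed into the backward pass but not logged individually. A hedged sketch of reading the logged items, assuming `trainer` was set up as in `train()` above and `batch` is a preprocessed batch:

```python
# loss_items follows the order of trainer.loss_names: ('giou_loss', 'cls_loss', 'l1_loss')
loss_sum, loss_items = trainer.model.loss(batch)
print(dict(zip(trainer.loss_names, loss_items.tolist())))
```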
@ -2,10 +2,12 @@
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.data import YOLODataset
|
||||
from ultralytics.yolo.data.augment import Compose, Format, LetterBox
|
||||
from ultralytics.yolo.data.augment import Compose, Format, v8_transforms
|
||||
from ultralytics.yolo.utils import colorstr, ops
|
||||
from ultralytics.yolo.v8.detect import DetectionValidator
|
||||
|
||||
@ -18,9 +20,41 @@ class RTDETRDataset(YOLODataset):
|
||||
def __init__(self, *args, data=None, **kwargs):
|
||||
super().__init__(*args, data=data, use_segments=False, use_keypoints=False, **kwargs)
|
||||
|
||||
# NOTE: add stretch version load_image for rtdetr mosaic
|
||||
def load_image(self, i):
|
||||
"""Loads 1 image from dataset index 'i', returns (im, resized hw)."""
|
||||
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
|
||||
if im is None: # not cached in RAM
|
||||
if fn.exists(): # load npy
|
||||
im = np.load(fn)
|
||||
else: # read image
|
||||
im = cv2.imread(f) # BGR
|
||||
if im is None:
|
||||
raise FileNotFoundError(f'Image Not Found {f}')
|
||||
h0, w0 = im.shape[:2] # orig hw
|
||||
im = cv2.resize(im, (self.imgsz, self.imgsz), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
# Add to buffer if training with augmentations
|
||||
if self.augment:
|
||||
self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
|
||||
self.buffer.append(i)
|
||||
if len(self.buffer) >= self.max_buffer_length:
|
||||
j = self.buffer.pop(0)
|
||||
self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None
|
||||
|
||||
return im, (h0, w0), im.shape[:2]
|
||||
|
||||
return self.ims[i], self.im_hw0[i], self.im_hw[i]
|
||||
|
||||
def build_transforms(self, hyp=None):
|
||||
"""Temporarily, only for evaluation."""
|
||||
transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scaleFill=True)])
|
||||
if self.augment:
|
||||
hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
|
||||
hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
|
||||
transforms = v8_transforms(self, self.imgsz, hyp, stretch=True)
|
||||
else:
|
||||
# transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scaleFill=True)])
|
||||
transforms = Compose([])
|
||||
transforms.append(
|
||||
Format(bbox_format='xywh',
|
||||
normalize=True,
|
||||
@ -65,6 +99,8 @@ class RTDETRValidator(DetectionValidator):
|
||||
# Do not need threshold for evaluation as only got 300 boxes here.
|
||||
# idx = score > self.args.conf
|
||||
pred = torch.cat([bbox, score[..., None], cls[..., None]], dim=-1) # filter
|
||||
# sort by confidence to correctly get internal metrics.
|
||||
pred = pred[score.argsort(descending=True)]
|
||||
outputs[i] = pred # [idx]
|
||||
|
||||
return outputs
|
||||
@ -100,7 +136,8 @@ class RTDETRValidator(DetectionValidator):
|
||||
tbox[..., [0, 2]] *= shape[1] # native-space pred
|
||||
tbox[..., [1, 3]] *= shape[0] # native-space pred
|
||||
labelsn = torch.cat((cls, tbox), 1) # native-space labels
|
||||
correct_bboxes = self._process_batch(predn, labelsn)
|
||||
# NOTE: To get correct metrics, the inputs of `_process_batch` should always be float32 type.
|
||||
correct_bboxes = self._process_batch(predn.float(), labelsn)
|
||||
# TODO: maybe remove these `self.` arguments as they already are member variable
|
||||
if self.args.plots:
|
||||
self.confusion_matrix.process_batch(predn, labelsn)
|
||||
|
@@ -256,10 +256,8 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup
        return mask, False
    fill_labels = [0] + small_regions
    if not correct_holes:
-        fill_labels = [i for i in range(n_labels) if i not in fill_labels]
        # If every region is below threshold, keep largest
-        if not fill_labels:
-            fill_labels = [int(np.argmax(sizes)) + 1]
+        fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]
    mask = np.isin(regions, fill_labels)
    return mask, True
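The hunk above collapses the empty-list fallback into a single expression: Python's `or` returns its right operand when the left one is an empty (falsy) list. A tiny self-contained illustration with made-up region sizes:

```python
import numpy as np

sizes = [120, 35]                                  # hypothetical connected-region areas
fill_labels = [] or [int(np.argmax(sizes)) + 1]    # nothing survived the filter -> keep largest: [1]
fill_labels = [2] or [int(np.argmax(sizes)) + 1]   # non-empty list wins -> [2]
print(fill_labels)
```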
@@ -18,14 +18,12 @@ class Sam(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = 'RGB'

-    def __init__(
-        self,
-        image_encoder: ImageEncoderViT,
-        prompt_encoder: PromptEncoder,
-        mask_decoder: MaskDecoder,
-        pixel_mean: List[float] = [123.675, 116.28, 103.53],
-        pixel_std: List[float] = [58.395, 57.12, 57.375],
-    ) -> None:
+    def __init__(self,
+                 image_encoder: ImageEncoderViT,
+                 prompt_encoder: PromptEncoder,
+                 mask_decoder: MaskDecoder,
+                 pixel_mean: List[float] = None,
+                 pixel_std: List[float] = None) -> None:
        """
        SAM predicts object masks from an image and input prompts.

@@ -38,6 +36,10 @@ class Sam(nn.Module):
        pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
        pixel_std (list(float)): Std values for normalizing pixels in the input image.
        """
+        if pixel_mean is None:
+            pixel_mean = [123.675, 116.28, 103.53]
+        if pixel_std is None:
+            pixel_std = [58.395, 57.12, 57.375]
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
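The signature change above replaces mutable list defaults with `None` sentinels that are filled in inside the body, the usual Python idiom for keeping a default object from being shared across calls. A small illustration of the pitfall the idiom avoids (hypothetical functions, not from this codebase):

```python
def bad(values=[]):        # a single list object is reused for every call
    values.append(1)
    return values

def good(values=None):     # mirrors the Sam.__init__ pattern: fresh default per call
    if values is None:
        values = []
    values.append(1)
    return values

print(bad())   # [1]
print(bad())   # [1, 1]  <- state leaked from the previous call
print(good())  # [1]
print(good())  # [1]
```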
291
ultralytics/vit/utils/loss.py
Normal file
291
ultralytics/vit/utils/loss.py
Normal file
@ -0,0 +1,291 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ultralytics.vit.utils.ops import HungarianMatcher
|
||||
from ultralytics.yolo.utils.loss import FocalLoss, VarifocalLoss
|
||||
from ultralytics.yolo.utils.metrics import bbox_iou
|
||||
|
||||
|
||||
class DETRLoss(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
nc=80,
|
||||
loss_gain=None,
|
||||
aux_loss=True,
|
||||
use_fl=True,
|
||||
use_vfl=False,
|
||||
use_uni_match=False,
|
||||
uni_match_ind=0):
|
||||
"""
|
||||
Args:
|
||||
nc (int): The number of classes.
|
||||
loss_gain (dict): The coefficient of loss.
|
||||
aux_loss (bool): If 'aux_loss = True', loss at each decoder layer are to be used.
|
||||
use_focal_loss (bool): Use focal loss or not.
|
||||
use_vfl (bool): Use VarifocalLoss or not.
|
||||
use_uni_match (bool): Whether to use a fixed layer to assign labels for auxiliary branch.
|
||||
uni_match_ind (int): The fixed indices of a layer.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if loss_gain is None:
|
||||
loss_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'no_object': 0.1, 'mask': 1, 'dice': 1}
|
||||
self.nc = nc
|
||||
self.matcher = HungarianMatcher(cost_gain={'class': 2, 'bbox': 5, 'giou': 2})
|
||||
self.loss_gain = loss_gain
|
||||
self.aux_loss = aux_loss
|
||||
self.fl = FocalLoss() if use_fl else None
|
||||
self.vfl = VarifocalLoss() if use_vfl else None
|
||||
|
||||
self.use_uni_match = use_uni_match
|
||||
self.uni_match_ind = uni_match_ind
|
||||
self.device = None
|
||||
|
||||
def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=''):
|
||||
# logits: [b, query, num_classes], gt_class: list[[n, 1]]
|
||||
name_class = f'loss_class{postfix}'
|
||||
bs, nq = pred_scores.shape[:2]
|
||||
# one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes)
|
||||
one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
|
||||
one_hot.scatter_(2, targets.unsqueeze(-1), 1)
|
||||
one_hot = one_hot[..., :-1]
|
||||
gt_scores = gt_scores.view(bs, nq, 1) * one_hot
|
||||
|
||||
if self.fl:
|
||||
if num_gts and self.vfl:
|
||||
loss_cls = self.vfl(pred_scores, gt_scores, one_hot)
|
||||
else:
|
||||
loss_cls = self.fl(pred_scores, one_hot.float())
|
||||
loss_cls /= max(num_gts, 1) / nq
|
||||
else:
|
||||
loss_cls = nn.BCEWithLogitsLoss(reduction='none')(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss
|
||||
|
||||
return {name_class: loss_cls.squeeze() * self.loss_gain['class']}
|
||||
|
||||
def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=''):
|
||||
# boxes: [b, query, 4], gt_bbox: list[[n, 4]]
|
||||
name_bbox = f'loss_bbox{postfix}'
|
||||
name_giou = f'loss_giou{postfix}'
|
||||
|
||||
loss = {}
|
||||
if len(gt_bboxes) == 0:
|
||||
loss[name_bbox] = torch.tensor(0., device=self.device)
|
||||
loss[name_giou] = torch.tensor(0., device=self.device)
|
||||
return loss
|
||||
|
||||
loss[name_bbox] = self.loss_gain['bbox'] * F.l1_loss(pred_bboxes, gt_bboxes, reduction='sum') / len(gt_bboxes)
|
||||
loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
|
||||
loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
|
||||
loss[name_giou] = self.loss_gain['giou'] * loss[name_giou]
|
||||
loss = {k: v.squeeze() for k, v in loss.items()}
|
||||
return loss
|
||||
|
||||
def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):
|
||||
# masks: [b, query, h, w], gt_mask: list[[n, H, W]]
|
||||
name_mask = f'loss_mask{postfix}'
|
||||
name_dice = f'loss_dice{postfix}'
|
||||
|
||||
loss = {}
|
||||
if sum(len(a) for a in gt_mask) == 0:
|
||||
loss[name_mask] = torch.tensor(0., device=self.device)
|
||||
loss[name_dice] = torch.tensor(0., device=self.device)
|
||||
return loss
|
||||
|
||||
num_gts = len(gt_mask)
|
||||
src_masks, target_masks = self._get_assigned_bboxes(masks, gt_mask, match_indices)
|
||||
src_masks = F.interpolate(src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode='bilinear')[0]
|
||||
# TODO: torch does not have `sigmoid_focal_loss`, but it's not urgent since we don't use mask branch for now.
|
||||
loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks,
|
||||
torch.tensor([num_gts], dtype=torch.float32))
|
||||
loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)
|
||||
return loss
|
||||
|
||||
def _dice_loss(self, inputs, targets, num_gts):
|
||||
inputs = F.sigmoid(inputs)
|
||||
inputs = inputs.flatten(1)
|
||||
targets = targets.flatten(1)
|
||||
numerator = 2 * (inputs * targets).sum(1)
|
||||
denominator = inputs.sum(-1) + targets.sum(-1)
|
||||
loss = 1 - (numerator + 1) / (denominator + 1)
|
||||
return loss.sum() / num_gts
|
||||
|
||||
def _get_loss_aux(self,
|
||||
pred_bboxes,
|
||||
pred_scores,
|
||||
gt_bboxes,
|
||||
gt_cls,
|
||||
gt_groups,
|
||||
match_indices=None,
|
||||
postfix='',
|
||||
masks=None,
|
||||
gt_mask=None):
|
||||
"""Get auxiliary losses"""
|
||||
# NOTE: loss class, bbox, giou, mask, dice
|
||||
loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
|
||||
if match_indices is None and self.use_uni_match:
|
||||
match_indices = self.matcher(pred_bboxes[self.uni_match_ind],
|
||||
pred_scores[self.uni_match_ind],
|
||||
gt_bboxes,
|
||||
gt_cls,
|
||||
gt_groups,
|
||||
masks=masks[self.uni_match_ind] if masks is not None else None,
|
||||
gt_mask=gt_mask)
|
||||
for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):
|
||||
aux_masks = masks[i] if masks is not None else None
|
||||
loss_ = self._get_loss(aux_bboxes,
|
||||
aux_scores,
|
||||
gt_bboxes,
|
||||
gt_cls,
|
||||
gt_groups,
|
||||
masks=aux_masks,
|
||||
gt_mask=gt_mask,
|
||||
postfix=postfix,
|
||||
match_indices=match_indices)
|
||||
loss[0] += loss_[f'loss_class{postfix}']
|
||||
loss[1] += loss_[f'loss_bbox{postfix}']
|
||||
loss[2] += loss_[f'loss_giou{postfix}']
|
||||
# if masks is not None and gt_mask is not None:
|
||||
# loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix)
|
||||
# loss[3] += loss_[f'loss_mask{postfix}']
|
||||
# loss[4] += loss_[f'loss_dice{postfix}']
|
||||
|
||||
loss = {
|
||||
f'loss_class_aux{postfix}': loss[0],
|
||||
f'loss_bbox_aux{postfix}': loss[1],
|
||||
f'loss_giou_aux{postfix}': loss[2]}
|
||||
# if masks is not None and gt_mask is not None:
|
||||
# loss[f'loss_mask_aux{postfix}'] = loss[3]
|
||||
# loss[f'loss_dice_aux{postfix}'] = loss[4]
|
||||
return loss
|
||||
|
||||
def _get_index(self, match_indices):
|
||||
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
|
||||
src_idx = torch.cat([src for (src, _) in match_indices])
|
||||
dst_idx = torch.cat([dst for (_, dst) in match_indices])
|
||||
return (batch_idx, src_idx), dst_idx
|
||||
|
||||
def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
|
||||
pred_assigned = torch.cat([
|
||||
t[I] if len(I) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
|
||||
for t, (I, _) in zip(pred_bboxes, match_indices)])
|
||||
gt_assigned = torch.cat([
|
||||
t[J] if len(J) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
|
||||
for t, (_, J) in zip(gt_bboxes, match_indices)])
|
||||
return pred_assigned, gt_assigned
|
||||
|
||||
def _get_loss(self,
|
||||
pred_bboxes,
|
||||
pred_scores,
|
||||
gt_bboxes,
|
||||
gt_cls,
|
||||
gt_groups,
|
||||
masks=None,
|
||||
gt_mask=None,
|
||||
postfix='',
|
||||
match_indices=None):
|
||||
"""Get losses"""
|
||||
if match_indices is None:
|
||||
match_indices = self.matcher(pred_bboxes,
|
||||
pred_scores,
|
||||
gt_bboxes,
|
||||
gt_cls,
|
||||
gt_groups,
|
||||
masks=masks,
|
||||
gt_mask=gt_mask)
|
||||
|
||||
idx, gt_idx = self._get_index(match_indices)
|
||||
pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]
|
||||
|
||||
bs, nq = pred_scores.shape[:2]
|
||||
targets = torch.full((bs, nq), self.nc, device=pred_scores.device, dtype=gt_cls.dtype)
|
||||
targets[idx] = gt_cls[gt_idx]
|
||||
|
||||
gt_scores = torch.zeros([bs, nq], device=pred_scores.device)
|
||||
if len(gt_bboxes):
|
||||
gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1)
|
||||
|
||||
loss = {}
|
||||
loss.update(self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix))
|
||||
loss.update(self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix))
|
||||
# if masks is not None and gt_mask is not None:
|
||||
# loss.update(self._get_loss_mask(masks, gt_mask, match_indices, postfix))
|
||||
return loss
|
||||
|
||||
def forward(self, pred_bboxes, pred_scores, batch, postfix='', **kwargs):
|
||||
"""
|
||||
Args:
|
||||
pred_bboxes (torch.Tensor): [l, b, query, 4]
|
||||
pred_scores (torch.Tensor): [l, b, query, num_classes]
|
||||
batch (dict): A dict includes:
|
||||
gt_cls (torch.Tensor) with shape [num_gts, ],
|
||||
gt_bboxes (torch.Tensor): [num_gts, 4],
|
||||
gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
|
||||
postfix (str): postfix of loss name.
|
||||
"""
|
||||
self.device = pred_bboxes.device
|
||||
match_indices = kwargs.get('match_indices', None)
|
||||
gt_cls, gt_bboxes, gt_groups = batch['cls'], batch['bboxes'], batch['gt_groups']
|
||||
|
||||
total_loss = self._get_loss(pred_bboxes[-1],
|
||||
pred_scores[-1],
|
||||
gt_bboxes,
|
||||
gt_cls,
|
||||
gt_groups,
|
||||
postfix=postfix,
|
||||
match_indices=match_indices)
|
||||
|
||||
if self.aux_loss:
|
||||
total_loss.update(
|
||||
self._get_loss_aux(pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices,
|
||||
postfix))
|
||||
|
||||
return total_loss
|
||||
|
||||
|
||||
class RTDETRDetectionLoss(DETRLoss):
|
||||
|
||||
def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None):
|
||||
pred_bboxes, pred_scores = preds
|
||||
total_loss = super().forward(pred_bboxes, pred_scores, batch)
|
||||
|
||||
if dn_meta is not None:
|
||||
dn_pos_idx, dn_num_group = dn_meta['dn_pos_idx'], dn_meta['dn_num_group']
|
||||
assert len(batch['gt_groups']) == len(dn_pos_idx)
|
||||
|
||||
# denoising match indices
|
||||
match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch['gt_groups'])
|
||||
|
||||
# compute denoising training loss
|
||||
dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix='_dn', match_indices=match_indices)
|
||||
total_loss.update(dn_loss)
|
||||
else:
|
||||
total_loss.update({f'{k}_dn': torch.tensor(0., device=self.device) for k in total_loss.keys()})
|
||||
|
||||
return total_loss
|
||||
|
||||
@staticmethod
|
||||
def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
|
||||
"""Get the match indices for denoising.
|
||||
|
||||
Args:
|
||||
dn_pos_idx (List[torch.Tensor]): A list includes positive indices of denoising.
|
||||
dn_num_group (int): The number of groups of denoising.
|
||||
gt_groups (List(int)): a list of batch size length includes the number of gts of each image.
|
||||
|
||||
Returns:
|
||||
dn_match_indices (List(tuple)): Matched indices.
|
||||
|
||||
"""
|
||||
dn_match_indices = []
|
||||
idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
|
||||
for i, num_gt in enumerate(gt_groups):
|
||||
if num_gt > 0:
|
||||
gt_idx = torch.arange(end=num_gt, dtype=torch.int32) + idx_groups[i]
|
||||
gt_idx = gt_idx.repeat(dn_num_group)
|
||||
assert len(dn_pos_idx[i]) == len(gt_idx), 'Expected the same length, '
|
||||
f'but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively.'
|
||||
dn_match_indices.append((dn_pos_idx[i], gt_idx))
|
||||
else:
|
||||
dn_match_indices.append((torch.zeros([0], dtype=torch.int32), torch.zeros([0], dtype=torch.int32)))
|
||||
return dn_match_indices
|
ultralytics/vit/utils/ops.py (new file, 230 lines)
@ -0,0 +1,230 @@
# TODO: license

import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment

from ultralytics.yolo.utils.metrics import bbox_iou
from ultralytics.yolo.utils.ops import xywh2xyxy, xyxy2xywh


class HungarianMatcher(nn.Module):

    def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0):
        """
        Args:
            cost_gain (dict): The coefficients of the hungarian matcher cost terms.
        """
        super().__init__()
        if cost_gain is None:
            cost_gain = {'class': 1, 'bbox': 5, 'giou': 2, 'mask': 1, 'dice': 1}
        self.cost_gain = cost_gain
        self.use_fl = use_fl
        self.with_mask = with_mask
        self.num_sample_points = num_sample_points
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None):
        """
        Args:
            pred_bboxes (Tensor): [b, query, 4]
            pred_scores (Tensor): [b, query, num_classes]
            gt_cls (torch.Tensor) with shape [num_gts, ]
            gt_bboxes (torch.Tensor): [num_gts, 4]
            gt_groups (List(int)): a list of length batch_size containing the number of gts for each image.
            masks (Tensor|None): [b, query, h, w]
            gt_mask (List(Tensor)): list[[n, H, W]]

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        bs, nq, nc = pred_scores.shape

        if sum(gt_groups) == 0:
            return [(torch.tensor([], dtype=torch.int32), torch.tensor([], dtype=torch.int32)) for _ in range(bs)]

        # We flatten to compute the cost matrices in a batch
        # [batch_size * num_queries, num_classes]
        pred_scores = pred_scores.detach().view(-1, nc)
        pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
        # [batch_size * num_queries, 4]
        pred_bboxes = pred_bboxes.detach().view(-1, 4)

        # Compute the classification cost
        pred_scores = pred_scores[:, gt_cls]
        if self.use_fl:
            neg_cost_class = (1 - self.alpha) * (pred_scores ** self.gamma) * (-(1 - pred_scores + 1e-8).log())
            pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())
            cost_class = pos_cost_class - neg_cost_class
        else:
            cost_class = -pred_scores

        # Compute the L1 cost between boxes
        cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1)  # (bs*num_queries, num_gt)

        # Compute the GIoU cost between boxes, (bs*num_queries, num_gt)
        cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)

        # Final cost matrix
        C = self.cost_gain['class'] * cost_class + \
            self.cost_gain['bbox'] * cost_bbox + \
            self.cost_gain['giou'] * cost_giou
        # Compute the mask cost and dice cost
        if self.with_mask:
            C += self._cost_mask(bs, gt_groups, masks, gt_mask)

        C = C.view(bs, nq, -1).cpu()
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
        gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
        # (idx for queries, idx for gt)
        return [(torch.tensor(i, dtype=torch.int32), torch.tensor(j, dtype=torch.int32) + gt_groups[k])
                for k, (i, j) in enumerate(indices)]

    def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
        assert masks is not None and gt_mask is not None, 'Make sure the input has `mask` and `gt_mask`'
        # all masks share the same set of points for efficient matching
        sample_points = torch.rand([bs, 1, self.num_sample_points, 2])
        sample_points = 2.0 * sample_points - 1.0

        out_mask = F.grid_sample(masks.detach(), sample_points, align_corners=False).squeeze(-2)
        out_mask = out_mask.flatten(0, 1)

        tgt_mask = torch.cat(gt_mask).unsqueeze(1)
        sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
        tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])

        with torch.cuda.amp.autocast(False):
            # binary cross entropy cost
            pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
            neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
            cost_mask = torch.matmul(pos_cost_mask, tgt_mask.T) + torch.matmul(neg_cost_mask, 1 - tgt_mask.T)
            cost_mask /= self.num_sample_points

            # dice cost
            out_mask = F.sigmoid(out_mask)
            numerator = 2 * torch.matmul(out_mask, tgt_mask.T)
            denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0)
            cost_dice = 1 - (numerator + 1) / (denominator + 1)

            C = self.cost_gain['mask'] * cost_mask + self.cost_gain['dice'] * cost_dice
        return C

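As a rough usage sketch of the matcher above (the shapes, class count and cost gains are made-up values for illustration, and `ultralytics` plus `scipy` must be importable):

```python
import torch

matcher = HungarianMatcher(cost_gain={'class': 2, 'bbox': 5, 'giou': 2})

pred_bboxes = torch.rand(2, 5, 4)    # 2 images, 5 queries, normalized xywh
pred_scores = torch.randn(2, 5, 3)   # raw logits for 3 classes
gt_bboxes = torch.rand(4, 4)         # 4 GT boxes across the batch, normalized xywh
gt_cls = torch.tensor([0, 2, 1, 0])
gt_groups = [3, 1]                   # 3 GTs in image 0, 1 GT in image 1

indices = matcher(pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups)
# indices[k] is a (query_idx, gt_idx) pair of int32 tensors for image k;
# gt_idx indexes into the flattened gt_bboxes/gt_cls tensors
```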
def get_cdn_group(batch,
                  num_classes,
                  num_queries,
                  class_embed,
                  num_dn=100,
                  cls_noise_ratio=0.5,
                  box_noise_scale=1.0,
                  training=False):
    """Get contrastive denoising training group.

    Args:
        batch (dict): A dict includes:
            gt_cls (torch.Tensor) with shape [num_gts, ],
            gt_bboxes (torch.Tensor): [num_gts, 4],
            gt_groups (List(int)): a list of length batch_size containing the number of gts for each image.
        num_classes (int): Number of classes.
        num_queries (int): Number of queries.
        class_embed (torch.Tensor): Embedding weights to map cls to embedding space.
        num_dn (int): Number of denoising queries.
        cls_noise_ratio (float): Noise ratio for class labels.
        box_noise_scale (float): Noise scale for bboxes.
        training (bool): If it's training or not.

    Returns:
        The padded class embeddings, padded bboxes, attention mask and denoising meta information,
        or four `None` values when denoising is disabled.
    """
    if (not training) or num_dn <= 0:
        return None, None, None, None
    gt_groups = batch['gt_groups']
    total_num = sum(gt_groups)
    max_nums = max(gt_groups)
    if max_nums == 0:
        return None, None, None, None

    num_group = num_dn // max_nums
    num_group = 1 if num_group == 0 else num_group
    # pad gt to max_num of a batch
    bs = len(gt_groups)
    gt_cls = batch['cls']  # (bs*num, )
    gt_bbox = batch['bboxes']  # bs*num, 4
    b_idx = batch['batch_idx']

    # each group has positive and negative queries.
    dn_cls = gt_cls.repeat(2 * num_group)  # (2*num_group*bs*num, )
    dn_bbox = gt_bbox.repeat(2 * num_group, 1)  # 2*num_group*bs*num, 4
    dn_b_idx = b_idx.repeat(2 * num_group).view(-1)  # (2*num_group*bs*num, )

    # positive and negative mask
    # (bs*num*num_group, ), the second total_num*num_group part as negative samples
    neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num

    if cls_noise_ratio > 0:
        # half of bbox prob
        mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
        idx = torch.nonzero(mask).squeeze(-1)
        # randomly put a new one here
        new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
        dn_cls[idx] = new_label

    if box_noise_scale > 0:
        known_bbox = xywh2xyxy(dn_bbox)

        diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale  # 2*num_group*bs*num, 4

        rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0
        rand_part = torch.rand_like(dn_bbox)
        rand_part[neg_idx] += 1.0
        rand_part *= rand_sign
        known_bbox += rand_part * diff
        known_bbox.clip_(min=0.0, max=1.0)
        dn_bbox = xyxy2xywh(known_bbox)
        dn_bbox = inverse_sigmoid(dn_bbox)

    # total denoising queries
    num_dn = int(max_nums * 2 * num_group)
    # class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)])
    dn_cls_embed = class_embed[dn_cls]  # bs*num * 2 * num_group, 256
    padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
    padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)

    map_indices = torch.cat([torch.tensor(range(num), dtype=torch.long) for num in gt_groups])
    pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0)

    map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])
    padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed
    padding_bbox[(dn_b_idx, map_indices)] = dn_bbox

    tgt_size = num_dn + num_queries
    attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)
    # match query cannot see the reconstruct
    attn_mask[num_dn:, :num_dn] = True
    # reconstruct cannot see each other
    for i in range(num_group):
        if i == 0:
            attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
        if i == num_group - 1:
            attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * i * 2] = True
        else:
            attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), max_nums * 2 * (i + 1):num_dn] = True
            attn_mask[max_nums * 2 * i:max_nums * 2 * (i + 1), :max_nums * 2 * i] = True
    dn_meta = {
        'dn_pos_idx': [p.reshape(-1) for p in pos_idx.cpu().split([n for n in gt_groups], dim=1)],
        'dn_num_group': num_group,
        'dn_num_split': [num_dn, num_queries]}

    return padding_cls.to(class_embed.device), padding_bbox.to(class_embed.device), attn_mask.to(
        class_embed.device), dn_meta


def inverse_sigmoid(x, eps=1e-6):
    x = x.clip(min=0., max=1.)
    return torch.log(x / (1 - x + eps) + eps)
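`inverse_sigmoid` maps the noised, normalized boxes into logit space so that the decoder's sigmoid recovers values in [0, 1]. A quick standalone sanity check using the same formula as above:

```python
import torch

def inverse_sigmoid(x, eps=1e-6):
    x = x.clip(min=0., max=1.)
    return torch.log(x / (1 - x + eps) + eps)

x = torch.tensor([0.25, 0.50, 0.90])
print(torch.sigmoid(inverse_sigmoid(x)))  # ~tensor([0.2500, 0.5000, 0.9000])
```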
@ -759,7 +759,7 @@ class Format:
        return masks, instances, cls


def v8_transforms(dataset, imgsz, hyp):
def v8_transforms(dataset, imgsz, hyp, stretch=False):
    """Convert images to a size suitable for YOLOv8 training."""
    pre_transform = Compose([
        Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic),
@ -770,7 +770,7 @@ def v8_transforms(dataset, imgsz, hyp):
            scale=hyp.scale,
            shear=hyp.shear,
            perspective=hyp.perspective,
            pre_transform=LetterBox(new_shape=(imgsz, imgsz)),
            pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)),
        )])
    flip_idx = dataset.data.get('flip_idx', None)  # for keypoints augmentation
    if dataset.use_keypoints:
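The new `stretch` flag lets callers drop the `LetterBox` pre-transform so images are simply stretched to the target size, presumably so detectors trained on square-stretched inputs (such as RT-DETR) can bypass letterboxing; the diff itself does not state the motivation.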
@ -278,7 +278,8 @@ class BaseTrainer:
        self.epoch_time_start = time.time()
        self.train_time_start = time.time()
        nb = len(self.train_loader)  # number of batches
        nw = max(round(self.args.warmup_epochs * nb), 100)  # number of warmup iterations
        nw = max(round(self.args.warmup_epochs *
                       nb), 100) if self.args.warmup_epochs > 0 else -1  # number of warmup iterations
        last_opt_step = -1
        self.run_callbacks('on_train_start')
        LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
@ -24,10 +24,34 @@ class VarifocalLoss(nn.Module):
        weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
        with torch.cuda.amp.autocast(enabled=False):
            loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') *
                    weight).sum()
                    weight).mean(1).sum()
        return loss


# Losses
class FocalLoss(nn.Module):
    """Applies focal loss directly to raw logits and binary targets, i.e. loss = FocalLoss()(pred, label)."""

    def __init__(self, ):
        super().__init__()

    def forward(self, pred, label, gamma=1.5, alpha=0.25):
        """Calculates focal loss between predicted logits and binary targets."""
        loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none')
        # p_t = torch.exp(-loss)
        # loss *= self.alpha * (1.000001 - p_t) ** self.gamma  # non-zero power for gradient stability

        # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
        pred_prob = pred.sigmoid()  # prob from logits
        p_t = label * pred_prob + (1 - label) * (1 - pred_prob)
        modulating_factor = (1.0 - p_t) ** gamma
        loss *= modulating_factor
        if alpha > 0:
            alpha_factor = label * alpha + (1 - label) * (1 - alpha)
            loss *= alpha_factor
        return loss.mean(1).sum()


class BboxLoss(nn.Module):

    def __init__(self, reg_max, use_dfl=False):
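A minimal sketch of calling the relocated FocalLoss on raw logits; the class defined in the hunk above is assumed to be in scope, and the shapes are illustrative only:

```python
import torch

fl = FocalLoss()
pred = torch.randn(8, 80)   # raw logits, (batch, num_classes)
label = torch.zeros(8, 80)  # one-hot targets
label[torch.arange(8), torch.randint(0, 80, (8,))] = 1.0

loss = fl(pred, label, gamma=1.5, alpha=0.25)  # mean over classes, summed over the batch -> scalar
```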
@ -9,7 +9,6 @@ from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

from ultralytics.yolo.utils import LOGGER, SimpleClass, TryExcept, plt_settings

@ -175,40 +174,6 @@ def smooth_BCE(eps=0.1):  # https://github.com/ultralytics/yolov3/issues/238#iss
    return 1.0 - 0.5 * eps, 0.5 * eps


# Losses
class FocalLoss(nn.Module):
    """Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""

    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
        """Initialize FocalLoss object with given loss function and hyperparameters."""
        super().__init__()
        self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = loss_fcn.reduction
        self.loss_fcn.reduction = 'none'  # required to apply FL to each element

    def forward(self, pred, true):
        """Calculates and updates confusion matrix for object detection/classification tasks."""
        loss = self.loss_fcn(pred, true)
        # p_t = torch.exp(-loss)
        # loss *= self.alpha * (1.000001 - p_t) ** self.gamma  # non-zero power for gradient stability

        # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
        pred_prob = torch.sigmoid(pred)  # prob from logits
        p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
        alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
        modulating_factor = (1.0 - p_t) ** self.gamma
        loss *= alpha_factor * modulating_factor

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:  # 'None'
            return loss


class ConfusionMatrix:
    """
    A class for calculating and updating a confusion matrix for object detection and classification tasks.
@ -327,6 +327,9 @@ def init_seeds(seed=0, deterministic=False):
            os.environ['PYTHONHASHSEED'] = str(seed)
        else:
            LOGGER.warning('WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.')
    else:
        torch.use_deterministic_algorithms(False)
        torch.backends.cudnn.deterministic = False


class ModelEMA:
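A short usage sketch for the updated seeding behaviour; the import path is an assumption based on the function name, since the hunk does not show the file it belongs to:

```python
from ultralytics.yolo.utils.torch_utils import init_seeds  # assumed module path

init_seeds(seed=0, deterministic=True)   # opts into deterministic algorithms on torch>=2.0.0
init_seeds(seed=0, deterministic=False)  # the new branch above now explicitly switches deterministic algorithms off
```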