mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
Removed redundency in RTDETRDecoder (#4118)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
parent
c06e9ae630
commit
8870084645
@ -307,8 +307,6 @@ class RTDETRDecoder(nn.Module):
|
||||
features = self.enc_output(valid_mask * feats) # bs, h*w, 256
|
||||
|
||||
enc_outputs_scores = self.enc_score_head(features) # (bs, h*w, nc)
|
||||
# dynamic anchors + static content
|
||||
enc_outputs_bboxes = self.enc_bbox_head(features) + anchors # (bs, h*w, 4)
|
||||
|
||||
# query selection
|
||||
# (bs, num_queries)
|
||||
@ -316,22 +314,23 @@ class RTDETRDecoder(nn.Module):
|
||||
# (bs, num_queries)
|
||||
batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1)
|
||||
|
||||
# Unsigmoided
|
||||
refer_bbox = enc_outputs_bboxes[batch_ind, topk_ind].view(bs, self.num_queries, -1)
|
||||
# refer_bbox = torch.gather(enc_outputs_bboxes, 1, topk_ind.reshape(bs, self.num_queries).unsqueeze(-1).repeat(1, 1, 4))
|
||||
# (bs, num_queries, 256)
|
||||
top_k_features = features[batch_ind, topk_ind].view(bs, self.num_queries, -1)
|
||||
# (bs, num_queries, 4)
|
||||
top_k_anchors = anchors[:, topk_ind].view(bs, self.num_queries, -1)
|
||||
|
||||
# dynamic anchors + static content
|
||||
refer_bbox = self.enc_bbox_head(top_k_features) + top_k_anchors
|
||||
|
||||
enc_bboxes = refer_bbox.sigmoid()
|
||||
if dn_bbox is not None:
|
||||
refer_bbox = torch.cat([dn_bbox, refer_bbox], 1)
|
||||
if self.training:
|
||||
refer_bbox = refer_bbox.detach()
|
||||
enc_scores = enc_outputs_scores[batch_ind, topk_ind].view(bs, self.num_queries, -1)
|
||||
|
||||
if self.learnt_init_query:
|
||||
embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
|
||||
else:
|
||||
embeddings = features[batch_ind, topk_ind].view(bs, self.num_queries, -1)
|
||||
if self.training:
|
||||
embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1) if self.learnt_init_query else top_k_features
|
||||
if self.training:
|
||||
refer_bbox = refer_bbox.detach()
|
||||
if not self.learnt_init_query:
|
||||
embeddings = embeddings.detach()
|
||||
if dn_embed is not None:
|
||||
embeddings = torch.cat([dn_embed, embeddings], 1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user