mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-24 06:14:55 +08:00
[RTDETR]Fix val loss (#3280)
This commit is contained in:
parent
d8701b42ca
commit
9d1e5567de
@ -158,6 +158,7 @@ class Classify(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class RTDETRDecoder(nn.Module):
|
class RTDETRDecoder(nn.Module):
|
||||||
|
export = False # export mode
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -246,9 +247,12 @@ class RTDETRDecoder(nn.Module):
|
|||||||
self.dec_score_head,
|
self.dec_score_head,
|
||||||
self.query_pos_head,
|
self.query_pos_head,
|
||||||
attn_mask=attn_mask)
|
attn_mask=attn_mask)
|
||||||
if not self.training:
|
x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta
|
||||||
dec_scores = dec_scores.sigmoid_()
|
if self.training:
|
||||||
return dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta
|
return x
|
||||||
|
# (bs, 300, 4+nc)
|
||||||
|
y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1)
|
||||||
|
return y if self.export else (y, x)
|
||||||
|
|
||||||
def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2):
|
def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device='cpu', eps=1e-2):
|
||||||
anchors = []
|
anchors = []
|
||||||
|
@ -432,7 +432,7 @@ class RTDETRDetectionModel(DetectionModel):
|
|||||||
'gt_groups': gt_groups}
|
'gt_groups': gt_groups}
|
||||||
|
|
||||||
preds = self.predict(img, batch=targets) if preds is None else preds
|
preds = self.predict(img, batch=targets) if preds is None else preds
|
||||||
dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds
|
dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1]
|
||||||
if dn_meta is None:
|
if dn_meta is None:
|
||||||
dn_bboxes, dn_scores = None, None
|
dn_bboxes, dn_scores = None, None
|
||||||
else:
|
else:
|
||||||
|
@ -12,8 +12,8 @@ class RTDETRPredictor(BasePredictor):
|
|||||||
|
|
||||||
def postprocess(self, preds, img, orig_imgs):
|
def postprocess(self, preds, img, orig_imgs):
|
||||||
"""Postprocess predictions and returns a list of Results objects."""
|
"""Postprocess predictions and returns a list of Results objects."""
|
||||||
bboxes, scores = preds[:2] # (1, bs, 300, 4), (1, bs, 300, nc)
|
nd = preds[0].shape[-1]
|
||||||
bboxes, scores = bboxes.squeeze_(0), scores.squeeze_(0)
|
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
|
||||||
results = []
|
results = []
|
||||||
for i, bbox in enumerate(bboxes): # (300, 4)
|
for i, bbox in enumerate(bboxes): # (300, 4)
|
||||||
bbox = ops.xywh2xyxy(bbox)
|
bbox = ops.xywh2xyxy(bbox)
|
||||||
|
@ -89,9 +89,8 @@ class RTDETRValidator(DetectionValidator):
|
|||||||
|
|
||||||
def postprocess(self, preds):
|
def postprocess(self, preds):
|
||||||
"""Apply Non-maximum suppression to prediction outputs."""
|
"""Apply Non-maximum suppression to prediction outputs."""
|
||||||
bboxes, scores = preds[:2] # (1, bs, 300, 4), (1, bs, 300, nc)
|
bs, _, nd = preds[0].shape
|
||||||
bboxes, scores = bboxes.squeeze_(0), scores.squeeze_(0) # (bs, 300, 4)
|
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
|
||||||
bs = len(bboxes)
|
|
||||||
outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs
|
outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs
|
||||||
for i, bbox in enumerate(bboxes): # (300, 4)
|
for i, bbox in enumerate(bboxes): # (300, 4)
|
||||||
bbox = ops.xywh2xyxy(bbox)
|
bbox = ops.xywh2xyxy(bbox)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user