Mirror of https://github.com/THU-MIG/yolov10.git, synced 2025-05-24 06:14:55 +08:00
Update ViT ops.py to torch.long (#3508)

Co-authored-by: Laughing-q <1185102784@qq.com>
commit 31b46bf2b4 (parent 8a11eda4a9)
@@ -284,11 +284,11 @@ class RTDETRDetectionLoss(DETRLoss):
         idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
         for i, num_gt in enumerate(gt_groups):
             if num_gt > 0:
-                gt_idx = torch.arange(end=num_gt, dtype=torch.int32) + idx_groups[i]
+                gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
                 gt_idx = gt_idx.repeat(dn_num_group)
                 assert len(dn_pos_idx[i]) == len(gt_idx), 'Expected the same length, '
                 f'but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively.'
                 dn_match_indices.append((dn_pos_idx[i], gt_idx))
             else:
-                dn_match_indices.append((torch.zeros([0], dtype=torch.int32), torch.zeros([0], dtype=torch.int32)))
+                dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
         return dn_match_indices
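The switch from torch.int32 to torch.long matters because the (query, gt) pairs collected in dn_match_indices are later used as index tensors, and several PyTorch index-consuming ops (torch.gather, Tensor.scatter_) accept only int64 indices. A minimal sketch of the failure mode this avoids, using toy tensors rather than the actual RT-DETR loss:

    import torch

    preds = torch.randn(4, 2)
    idx64 = torch.arange(3, dtype=torch.long)
    idx32 = torch.arange(3, dtype=torch.int32)

    # gather() requires int64 ("long") index tensors and rejects int32.
    print(torch.gather(preds, 0, idx64.unsqueeze(1).expand(3, 2)).shape)  # torch.Size([3, 2])
    try:
        torch.gather(preds, 0, idx32.unsqueeze(1).expand(3, 2))
    except RuntimeError as err:
        print(err)  # gather(): Expected dtype int64 for index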
@@ -71,7 +71,7 @@ class HungarianMatcher(nn.Module):
         bs, nq, nc = pred_scores.shape

         if sum(gt_groups) == 0:
-            return [(torch.tensor([], dtype=torch.int32), torch.tensor([], dtype=torch.int32)) for _ in range(bs)]
+            return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]

         # We flatten to compute the cost matrices in a batch
         # [batch_size * num_queries, num_classes]
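The empty-image early return gets the same treatment: even a zero-length index tensor carries a dtype, and keeping it torch.long makes the placeholder interchangeable with real match indices downstream. A small illustration with toy shapes:

    import torch

    empty_idx = torch.tensor([], dtype=torch.long)     # matcher output for an image with no GT
    real_idx = torch.tensor([2, 0], dtype=torch.long)  # matcher output for a normal image

    # Indexing with the empty tensor yields an empty slice instead of an error,
    # and the two tensors concatenate cleanly because their dtypes agree.
    queries = torch.randn(5, 8)
    print(queries[empty_idx].shape)          # torch.Size([0, 8])
    print(torch.cat([empty_idx, real_idx]))  # tensor([2, 0])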
@@ -107,7 +107,7 @@ class HungarianMatcher(nn.Module):
         indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
         gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
         # (idx for queries, idx for gt)
-        return [(torch.tensor(i, dtype=torch.int32), torch.tensor(j, dtype=torch.int32) + gt_groups[k])
+        return [(torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k])
                 for k, (i, j) in enumerate(indices)]

     def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
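In this last hunk, scipy.optimize.linear_sum_assignment returns NumPy integer arrays rather than torch tensors, so pinning dtype=torch.long when wrapping them guarantees int64 index tensors regardless of the arrays' native integer dtype. A standalone sketch of that conversion on a toy 3x3 cost matrix:

    import numpy as np
    import torch
    from scipy.optimize import linear_sum_assignment

    # Toy cost matrix; the optimal assignment is rows [0, 1, 2] -> cols [1, 0, 2].
    cost = np.array([[4.0, 1.0, 3.0],
                     [2.0, 0.0, 5.0],
                     [3.0, 2.0, 2.0]])
    row_ind, col_ind = linear_sum_assignment(cost)

    # Forcing torch.long yields consistent int64 index tensors on every platform.
    rows = torch.tensor(row_ind, dtype=torch.long)
    cols = torch.tensor(col_ind, dtype=torch.long)
    print(rows, cols)  # tensor([0, 1, 2]) tensor([1, 0, 2])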