Mirror of https://github.com/THU-MIG/yolov10.git, synced 2025-05-23 21:44:22 +08:00
Fix TFLite INT8 for OBB (#7989)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
parent e62d9cfe07
commit 70a6ef9c7e
@@ -59,16 +59,17 @@ class Detect(nn.Module):
             cls = x_cat[:, self.reg_max * 4 :]
         else:
             box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
-        dbox = self.decode_bboxes(box)

         if self.export and self.format in ("tflite", "edgetpu"):
             # Precompute normalization factor to increase numerical stability
             # See https://github.com/ultralytics/ultralytics/issues/7371
-            img_h = shape[2]
-            img_w = shape[3]
-            img_size = torch.tensor([img_w, img_h, img_w, img_h], device=box.device).reshape(1, 4, 1)
-            norm = self.strides / (self.stride[0] * img_size)
-            dbox = dist2bbox(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2], xywh=True, dim=1)
+            grid_h = shape[2]
+            grid_w = shape[3]
+            grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
+            norm = self.strides / (self.stride[0] * grid_size)
+            dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
+        else:
+            dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides

         y = torch.cat((dbox, cls.sigmoid()), 1)
         return y if self.export else (y, x)
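For context on the tflite/edgetpu branch above, here is a minimal, self-contained sketch (the toy shapes and the dist2bbox body below are assumptions for illustration, not the Ultralytics implementation) of why scaling both the DFL distances and the anchor points by the grid size keeps intermediate values near [0, 1], which quantizes better under INT8, while describing the same boxes up to a constant rescale:

import torch

def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Decode (left, top, right, bottom) distances around anchor points into boxes."""
    lt, rb = distance.chunk(2, dim)
    x1y1 = anchor_points - lt
    x2y2 = anchor_points + rb
    if xywh:
        c_xy = (x1y1 + x2y2) / 2
        wh = x2y2 - x1y1
        return torch.cat((c_xy, wh), dim)
    return torch.cat((x1y1, x2y2), dim)

# Toy single-head setup (hypothetical shapes): 8x8 grid, stride 8, batch of 1.
grid_h, grid_w, stride = 8, 8, 8.0
n = grid_h * grid_w
anchors = torch.rand(2, n)            # anchor centers in grid units
strides = torch.full((1, n), stride)  # per-anchor stride
box = torch.rand(1, 4, n) * 4         # DFL-decoded l/t/r/b distances, grid units

# Normalized path (as in the tflite/edgetpu branch): inputs stay roughly in [0, 1].
grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h]).reshape(1, 4, 1)
norm = strides / (stride * grid_size)
dbox_norm = dist2bbox(box * norm, anchors.unsqueeze(0) * norm[:, :2], xywh=True, dim=1)

# Standard path: decode in grid units, then scale by stride to pixels.
dbox = dist2bbox(box, anchors.unsqueeze(0), xywh=True, dim=1) * strides

# Same geometry: the normalized output times the image size equals the pixel-space output.
print(torch.allclose(dbox_norm * (stride * grid_size), dbox, atol=1e-5))  # True

The normalized branch therefore emits boxes in image-relative coordinates; multiplying back by the image size recovers the standard pixel-space result.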
@@ -82,9 +83,9 @@ class Detect(nn.Module):
             a[-1].bias.data[:] = 1.0  # box
             b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)

-    def decode_bboxes(self, bboxes):
+    def decode_bboxes(self, bboxes, anchors):
         """Decode bounding boxes."""
-        return dist2bbox(self.dfl(bboxes), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
+        return dist2bbox(bboxes, anchors, xywh=True, dim=1)


 class Segment(Detect):
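The signature change above is what lets the TFLite INT8 fix reach OBB: both the standard and the normalized export branches in Detect.forward now dispatch through self.decode_bboxes, so a subclass that overrides it has its decoder used during export as well. A hypothetical, simplified sketch of that dispatch pattern (TinyDetect and TinyOBB are illustrative names, not the real classes):

import torch

class TinyDetect:
    def decode_bboxes(self, dist, anchors):
        # axis-aligned decode: centers = anchors + (rb - lt) / 2, sizes = lt + rb
        lt, rb = dist.chunk(2, 1)
        return torch.cat((anchors + (rb - lt) / 2, lt + rb), 1)

    def forward(self, dist, anchors, norm=1.0):
        # both the normalized export branch (norm < 1) and the standard branch
        # go through self.decode_bboxes, so subclass overrides apply everywhere
        return self.decode_bboxes(dist * norm, anchors * norm)

class TinyOBB(TinyDetect):
    def __init__(self, angle):
        self.angle = angle  # predicted rotation per anchor, radians

    def decode_bboxes(self, dist, anchors):
        # rotated decode: rotate the center offset by the predicted angle
        lt, rb = dist.chunk(2, 1)
        ox, oy = ((rb - lt) / 2).chunk(2, 1)
        cos, sin = torch.cos(self.angle), torch.sin(self.angle)
        xy = torch.cat((ox * cos - oy * sin, ox * sin + oy * cos), 1) + anchors
        return torch.cat((xy, lt + rb), 1)

dist, anchors = torch.rand(1, 4, 5), torch.rand(1, 2, 5)
head = TinyOBB(angle=torch.rand(1, 1, 5))
print(head.forward(dist, anchors, norm=0.125).shape)  # rotated decoder used in the export path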
@@ -139,9 +140,9 @@ class OBB(Detect):
             return x, angle
         return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))

-    def decode_bboxes(self, bboxes):
+    def decode_bboxes(self, bboxes, anchors):
         """Decode rotated bounding boxes."""
-        return dist2rbox(self.dfl(bboxes), self.angle, self.anchors.unsqueeze(0), dim=1) * self.strides
+        return dist2rbox(bboxes, self.angle, anchors, dim=1)


 class Pose(Detect):
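OBB.decode_bboxes delegates to dist2rbox with the new (distances, angle, anchors) arguments. A rough sketch of what such a rotated decoder could look like (the body below is an assumption for illustration, not the Ultralytics utility):

import torch

def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
    """Sketch: rotate the center offset by the predicted angle before adding the anchor."""
    lt, rb = pred_dist.chunk(2, dim)
    cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)
    xf, yf = ((rb - lt) / 2).chunk(2, dim)  # center offset in the box frame
    x = xf * cos - yf * sin                 # rotate offset into the image frame
    y = xf * sin + yf * cos
    xy = torch.cat((x, y), dim) + anchor_points
    return torch.cat((xy, lt + rb), dim)    # (cx, cy, w, h); the angle is carried separately

dist = torch.rand(1, 4, 10)     # l/t/r/b distances for 10 anchors
angle = torch.rand(1, 1, 10)    # predicted rotation per anchor, radians
anchors = torch.rand(1, 2, 10)  # anchor centers
print(dist2rbox(dist, angle, anchors, dim=1).shape)  # torch.Size([1, 4, 10])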