mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-10-25 10:25:39 +08:00 
			
		
		
		
	Fix TFLite INT8 for OBB (#7989)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
		
							parent
							
								
									e62d9cfe07
								
							
						
					
					
						commit
						70a6ef9c7e
					
				| @ -59,16 +59,17 @@ class Detect(nn.Module): | ||||
|             cls = x_cat[:, self.reg_max * 4 :] | ||||
|         else: | ||||
|             box, cls = x_cat.split((self.reg_max * 4, self.nc), 1) | ||||
|         dbox = self.decode_bboxes(box) | ||||
| 
 | ||||
|         if self.export and self.format in ("tflite", "edgetpu"): | ||||
|             # Precompute normalization factor to increase numerical stability | ||||
|             # See https://github.com/ultralytics/ultralytics/issues/7371 | ||||
|             img_h = shape[2] | ||||
|             img_w = shape[3] | ||||
|             img_size = torch.tensor([img_w, img_h, img_w, img_h], device=box.device).reshape(1, 4, 1) | ||||
|             norm = self.strides / (self.stride[0] * img_size) | ||||
|             dbox = dist2bbox(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2], xywh=True, dim=1) | ||||
|             grid_h = shape[2] | ||||
|             grid_w = shape[3] | ||||
|             grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1) | ||||
|             norm = self.strides / (self.stride[0] * grid_size) | ||||
|             dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2]) | ||||
|         else: | ||||
|             dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides | ||||
| 
 | ||||
|         y = torch.cat((dbox, cls.sigmoid()), 1) | ||||
|         return y if self.export else (y, x) | ||||
| @ -82,9 +83,9 @@ class Detect(nn.Module): | ||||
|             a[-1].bias.data[:] = 1.0  # box | ||||
|             b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img) | ||||
| 
 | ||||
|     def decode_bboxes(self, bboxes): | ||||
|     def decode_bboxes(self, bboxes, anchors): | ||||
|         """Decode bounding boxes.""" | ||||
|         return dist2bbox(self.dfl(bboxes), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides | ||||
|         return dist2bbox(bboxes, anchors, xywh=True, dim=1) | ||||
| 
 | ||||
| 
 | ||||
| class Segment(Detect): | ||||
| @ -139,9 +140,9 @@ class OBB(Detect): | ||||
|             return x, angle | ||||
|         return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle)) | ||||
| 
 | ||||
|     def decode_bboxes(self, bboxes): | ||||
|     def decode_bboxes(self, bboxes, anchors): | ||||
|         """Decode rotated bounding boxes.""" | ||||
|         return dist2rbox(self.dfl(bboxes), self.angle, self.anchors.unsqueeze(0), dim=1) * self.strides | ||||
|         return dist2rbox(bboxes, self.angle, anchors, dim=1) | ||||
| 
 | ||||
| 
 | ||||
| class Pose(Detect): | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 AdamP
						AdamP