ultralytics 8.0.221 fix Apple MPS inference bug (#6694)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Johnny <johnnync13@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
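The underlying issue: on Apple's MPS backend, in-place tensor ops such as .clamp_() and .sigmoid_() applied to views or slices could silently produce wrong results during inference. The changes below therefore rewrite each in-place call as an out-of-place op whose result is assigned back, and make the clipping helpers return their output so callers rebind instead of mutate. A minimal sketch of the rewrite pattern (illustrative values, not repo code):

    import torch

    x = torch.tensor([[-5.0, 10.0, 700.0, 300.0]])  # a hypothetical xyxy box
    x[..., 0] = x[..., 0].clamp(0, 640)  # out-of-place, assigned back: safe on MPS
    # x[..., 0].clamp_(0, 640)           # in-place form this commit removes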
parent e504520448
commit 2e71f7f50e
@@ -117,8 +117,7 @@ class Results(SimpleClass):
     def update(self, boxes=None, masks=None, probs=None):
         """Update the boxes, masks, and probs attributes of the Results object."""
         if boxes is not None:
-            ops.clip_boxes(boxes, self.orig_shape)  # clip boxes
-            self.boxes = Boxes(boxes, self.orig_shape)
+            self.boxes = Boxes(ops.clip_boxes(boxes, self.orig_shape), self.orig_shape)
         if masks is not None:
             self.masks = Masks(masks, self.orig_shape)
         if probs is not None:
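Since ops.clip_boxes() now returns the clipped tensor, Results.update() clips and wraps in a single expression. A hedged usage sketch (import paths and the 6-column xyxy-conf-cls layout assumed for this release):

    import torch
    from ultralytics.engine.results import Boxes  # import path assumed
    from ultralytics.utils import ops

    orig_shape = (480, 640)  # (height, width)
    raw = torch.tensor([[-5.0, 10.0, 700.0, 300.0, 0.9, 0.0]])  # xyxy, conf, cls
    boxes = Boxes(ops.clip_boxes(raw, orig_shape), orig_shape)  # clip + wrap in one step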
@@ -141,7 +141,7 @@ class Pose(Detect):
         else:
             y = kpts.clone()
             if ndim == 3:
-                y[:, 2::3].sigmoid_()  # inplace sigmoid
+                y[:, 2::3] = y[:, 2::3].sigmoid()  # sigmoid (WARNING: inplace .sigmoid_() Apple MPS bug)
             y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (self.anchors[0] - 0.5)) * self.strides
             y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (self.anchors[1] - 0.5)) * self.strides
             return y
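The same rewrite applies to Pose keypoint decoding: the confidence channel (every third value) goes through an out-of-place sigmoid that is assigned back to the strided slice. A standalone sketch, assuming the usual (batch, 17 keypoints x 3, anchors) layout:

    import torch

    y = torch.randn(1, 51, 8400)       # hypothetical (bs, 17 * 3, num_anchors)
    y[:, 2::3] = y[:, 2::3].sigmoid()  # replaces the in-place y[:, 2::3].sigmoid_()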
@@ -109,8 +109,7 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True):
         boxes[..., [0, 2]] -= pad[0]  # x padding
         boxes[..., [1, 3]] -= pad[1]  # y padding
     boxes[..., :4] /= gain
-    clip_boxes(boxes, img0_shape)
-    return boxes
+    return clip_boxes(boxes, img0_shape)
 
 
 def make_divisible(x, divisor):
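scale_boxes() now returns the rescaled, clipped boxes rather than relying on in-place mutation of its argument. A usage sketch mapping letterboxed 640x640 predictions back to a 480x640 original (import path assumed, values illustrative):

    import torch
    from ultralytics.utils import ops

    pred = torch.tensor([[100.0, 80.0, 300.0, 240.0]])     # xyxy in 640x640 space
    boxes = ops.scale_boxes((640, 640), pred, (480, 640))  # -> [[100., 0., 300., 160.]]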
@@ -179,10 +178,6 @@ def non_max_suppression(
     if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation model, output = (inference_out, loss_out)
         prediction = prediction[0]  # select only inference output
 
-    device = prediction.device
-    mps = 'mps' in device.type  # Apple MPS
-    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
-        prediction = prediction.cpu()
     bs = prediction.shape[0]  # batch size
     nc = nc or (prediction.shape[1] - 4)  # number of classes
     nm = prediction.shape[1] - nc - 4
@@ -256,8 +251,6 @@
         #         i = i[iou.sum(1) > 1]  # require redundancy
 
         output[xi] = x[i]
-        if mps:
-            output[xi] = output[xi].to(device)
         if (time.time() - t) > time_limit:
             LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
             break  # time limit exceeded
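With the in-place ops fixed, the explicit CPU round-trip before NMS is dropped, so predictions can stay on the 'mps' device end to end, assuming the installed torch/torchvision build supports NMS on MPS. A hedged sketch:

    import torch
    from ultralytics.utils import ops  # import path assumed

    device = 'mps' if torch.backends.mps.is_available() else 'cpu'
    pred = torch.rand(1, 84, 8400, device=device)  # hypothetical raw output (4 box + 80 class rows)
    out = ops.non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)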
@@ -270,17 +263,21 @@ def clip_boxes(boxes, shape):
     Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape.
 
     Args:
-      boxes (torch.Tensor): the bounding boxes to clip
-      shape (tuple): the shape of the image
+        boxes (torch.Tensor): the bounding boxes to clip
+        shape (tuple): the shape of the image
+
+    Returns:
+        (torch.Tensor | numpy.ndarray): Clipped boxes
     """
-    if isinstance(boxes, torch.Tensor):  # faster individually
-        boxes[..., 0].clamp_(0, shape[1])  # x1
-        boxes[..., 1].clamp_(0, shape[0])  # y1
-        boxes[..., 2].clamp_(0, shape[1])  # x2
-        boxes[..., 3].clamp_(0, shape[0])  # y2
+    if isinstance(boxes, torch.Tensor):  # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
+        boxes[..., 0] = boxes[..., 0].clamp(0, shape[1])  # x1
+        boxes[..., 1] = boxes[..., 1].clamp(0, shape[0])  # y1
+        boxes[..., 2] = boxes[..., 2].clamp(0, shape[1])  # x2
+        boxes[..., 3] = boxes[..., 3].clamp(0, shape[0])  # y2
     else:  # np.array (faster grouped)
         boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1])  # x1, x2
         boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0])  # y1, y2
+    return boxes
 
 
 def clip_coords(coords, shape):
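The numpy branch was already out-of-place; the functional change is that clip_boxes() now returns its result on both branches, so callers can rely on the return value whatever the input type. For example:

    import numpy as np
    from ultralytics.utils import ops  # import path assumed

    b = np.array([[-10.0, 5.0, 700.0, 500.0]])
    b = ops.clip_boxes(b, (480, 640))  # -> [[0., 5., 640., 480.]]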
@@ -292,14 +289,15 @@ def clip_coords(coords, shape):
         shape (tuple): A tuple of integers representing the size of the image in the format (height, width).
 
     Returns:
-        (None): The function modifies the input `coordinates` in place, by clipping each coordinate to the image boundaries.
+        (torch.Tensor | numpy.ndarray): Clipped coordinates
     """
-    if isinstance(coords, torch.Tensor):  # faster individually
-        coords[..., 0].clamp_(0, shape[1])  # x
-        coords[..., 1].clamp_(0, shape[0])  # y
+    if isinstance(coords, torch.Tensor):  # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
+        coords[..., 0] = coords[..., 0].clamp(0, shape[1])  # x
+        coords[..., 1] = coords[..., 1].clamp(0, shape[0])  # y
     else:  # np.array (faster grouped)
         coords[..., 0] = coords[..., 0].clip(0, shape[1])  # x
         coords[..., 1] = coords[..., 1].clip(0, shape[0])  # y
+    return coords
 
 
 def scale_image(masks, im0_shape, ratio_pad=None):
@@ -418,7 +416,7 @@ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
         y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format
     """
     if clip:
-        clip_boxes(x, (h - eps, w - eps))  # warning: inplace clip
+        x = clip_boxes(x, (h - eps, w - eps))
     assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
     y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)  # faster than clone/copy
     y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w  # x center
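xyxy2xywhn() now rebinds x to the return value of clip_boxes() before normalizing. A quick worked example (import path assumed, values illustrative):

    import torch
    from ultralytics.utils import ops

    xyxy = torch.tensor([[0.0, 0.0, 320.0, 480.0]])
    xywhn = ops.xyxy2xywhn(xyxy, w=640, h=480, clip=True)  # -> [[0.25, 0.5, 0.5, 1.0]]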
@@ -740,7 +738,7 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False
         coords[..., 1] -= pad[1]  # y padding
     coords[..., 0] /= gain
     coords[..., 1] /= gain
-    clip_coords(coords, img0_shape)
+    coords = clip_coords(coords, img0_shape)
     if normalize:
         coords[..., 0] /= img0_shape[1]  # width
         coords[..., 1] /= img0_shape[0]  # height
@@ -353,7 +353,7 @@ def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False,
         b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # attempt rectangle to square
     b[:, 2:] = b[:, 2:] * gain + pad  # box wh * gain + pad
     xyxy = ops.xywh2xyxy(b).long()
-    ops.clip_boxes(xyxy, im.shape)
+    xyxy = ops.clip_boxes(xyxy, im.shape)
     crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
     if save:
         file.parent.mkdir(parents=True, exist_ok=True)  # make directory
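save_one_box() likewise rebinds xyxy to the clipped result before indexing into the image. A usage sketch on a blank image (import path assumed, shapes illustrative):

    import numpy as np
    import torch
    from ultralytics.utils.plotting import save_one_box

    im = np.zeros((480, 640, 3), dtype=np.uint8)  # HWC uint8 image
    crop = save_one_box(torch.tensor([[100.0, 50.0, 200.0, 150.0]]), im, save=False)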