mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
ultralytics 8.0.221
fix Apple MPS inference bug (#6694)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Johnny <johnnync13@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
e504520448
commit
2e71f7f50e
@ -117,8 +117,7 @@ class Results(SimpleClass):
|
|||||||
def update(self, boxes=None, masks=None, probs=None):
|
def update(self, boxes=None, masks=None, probs=None):
|
||||||
"""Update the boxes, masks, and probs attributes of the Results object."""
|
"""Update the boxes, masks, and probs attributes of the Results object."""
|
||||||
if boxes is not None:
|
if boxes is not None:
|
||||||
ops.clip_boxes(boxes, self.orig_shape) # clip boxes
|
self.boxes = Boxes(ops.clip_boxes(boxes, self.orig_shape), self.orig_shape)
|
||||||
self.boxes = Boxes(boxes, self.orig_shape)
|
|
||||||
if masks is not None:
|
if masks is not None:
|
||||||
self.masks = Masks(masks, self.orig_shape)
|
self.masks = Masks(masks, self.orig_shape)
|
||||||
if probs is not None:
|
if probs is not None:
|
||||||
|
@ -141,7 +141,7 @@ class Pose(Detect):
|
|||||||
else:
|
else:
|
||||||
y = kpts.clone()
|
y = kpts.clone()
|
||||||
if ndim == 3:
|
if ndim == 3:
|
||||||
y[:, 2::3].sigmoid_() # inplace sigmoid
|
y[:, 2::3] = y[:, 2::3].sigmoid() # sigmoid (WARNING: inplace .sigmoid_() Apple MPS bug)
|
||||||
y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (self.anchors[0] - 0.5)) * self.strides
|
y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (self.anchors[0] - 0.5)) * self.strides
|
||||||
y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (self.anchors[1] - 0.5)) * self.strides
|
y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (self.anchors[1] - 0.5)) * self.strides
|
||||||
return y
|
return y
|
||||||
|
@ -109,8 +109,7 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True):
|
|||||||
boxes[..., [0, 2]] -= pad[0] # x padding
|
boxes[..., [0, 2]] -= pad[0] # x padding
|
||||||
boxes[..., [1, 3]] -= pad[1] # y padding
|
boxes[..., [1, 3]] -= pad[1] # y padding
|
||||||
boxes[..., :4] /= gain
|
boxes[..., :4] /= gain
|
||||||
clip_boxes(boxes, img0_shape)
|
return clip_boxes(boxes, img0_shape)
|
||||||
return boxes
|
|
||||||
|
|
||||||
|
|
||||||
def make_divisible(x, divisor):
|
def make_divisible(x, divisor):
|
||||||
@ -179,10 +178,6 @@ def non_max_suppression(
|
|||||||
if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out)
|
if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out)
|
||||||
prediction = prediction[0] # select only inference output
|
prediction = prediction[0] # select only inference output
|
||||||
|
|
||||||
device = prediction.device
|
|
||||||
mps = 'mps' in device.type # Apple MPS
|
|
||||||
if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
|
|
||||||
prediction = prediction.cpu()
|
|
||||||
bs = prediction.shape[0] # batch size
|
bs = prediction.shape[0] # batch size
|
||||||
nc = nc or (prediction.shape[1] - 4) # number of classes
|
nc = nc or (prediction.shape[1] - 4) # number of classes
|
||||||
nm = prediction.shape[1] - nc - 4
|
nm = prediction.shape[1] - nc - 4
|
||||||
@ -256,8 +251,6 @@ def non_max_suppression(
|
|||||||
# i = i[iou.sum(1) > 1] # require redundancy
|
# i = i[iou.sum(1) > 1] # require redundancy
|
||||||
|
|
||||||
output[xi] = x[i]
|
output[xi] = x[i]
|
||||||
if mps:
|
|
||||||
output[xi] = output[xi].to(device)
|
|
||||||
if (time.time() - t) > time_limit:
|
if (time.time() - t) > time_limit:
|
||||||
LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
|
LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
|
||||||
break # time limit exceeded
|
break # time limit exceeded
|
||||||
@ -272,15 +265,19 @@ def clip_boxes(boxes, shape):
|
|||||||
Args:
|
Args:
|
||||||
boxes (torch.Tensor): the bounding boxes to clip
|
boxes (torch.Tensor): the bounding boxes to clip
|
||||||
shape (tuple): the shape of the image
|
shape (tuple): the shape of the image
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(torch.Tensor | numpy.ndarray): Clipped boxes
|
||||||
"""
|
"""
|
||||||
if isinstance(boxes, torch.Tensor): # faster individually
|
if isinstance(boxes, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
|
||||||
boxes[..., 0].clamp_(0, shape[1]) # x1
|
boxes[..., 0] = boxes[..., 0].clamp(0, shape[1]) # x1
|
||||||
boxes[..., 1].clamp_(0, shape[0]) # y1
|
boxes[..., 1] = boxes[..., 1].clamp(0, shape[0]) # y1
|
||||||
boxes[..., 2].clamp_(0, shape[1]) # x2
|
boxes[..., 2] = boxes[..., 2].clamp(0, shape[1]) # x2
|
||||||
boxes[..., 3].clamp_(0, shape[0]) # y2
|
boxes[..., 3] = boxes[..., 3].clamp(0, shape[0]) # y2
|
||||||
else: # np.array (faster grouped)
|
else: # np.array (faster grouped)
|
||||||
boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) # x1, x2
|
boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) # x1, x2
|
||||||
boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) # y1, y2
|
boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) # y1, y2
|
||||||
|
return boxes
|
||||||
|
|
||||||
|
|
||||||
def clip_coords(coords, shape):
|
def clip_coords(coords, shape):
|
||||||
@ -292,14 +289,15 @@ def clip_coords(coords, shape):
|
|||||||
shape (tuple): A tuple of integers representing the size of the image in the format (height, width).
|
shape (tuple): A tuple of integers representing the size of the image in the format (height, width).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(None): The function modifies the input `coordinates` in place, by clipping each coordinate to the image boundaries.
|
(torch.Tensor | numpy.ndarray): Clipped coordinates
|
||||||
"""
|
"""
|
||||||
if isinstance(coords, torch.Tensor): # faster individually
|
if isinstance(coords, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
|
||||||
coords[..., 0].clamp_(0, shape[1]) # x
|
coords[..., 0] = coords[..., 0].clamp(0, shape[1]) # x
|
||||||
coords[..., 1].clamp_(0, shape[0]) # y
|
coords[..., 1] = coords[..., 1].clamp(0, shape[0]) # y
|
||||||
else: # np.array (faster grouped)
|
else: # np.array (faster grouped)
|
||||||
coords[..., 0] = coords[..., 0].clip(0, shape[1]) # x
|
coords[..., 0] = coords[..., 0].clip(0, shape[1]) # x
|
||||||
coords[..., 1] = coords[..., 1].clip(0, shape[0]) # y
|
coords[..., 1] = coords[..., 1].clip(0, shape[0]) # y
|
||||||
|
return coords
|
||||||
|
|
||||||
|
|
||||||
def scale_image(masks, im0_shape, ratio_pad=None):
|
def scale_image(masks, im0_shape, ratio_pad=None):
|
||||||
@ -418,7 +416,7 @@ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
|
|||||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format
|
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format
|
||||||
"""
|
"""
|
||||||
if clip:
|
if clip:
|
||||||
clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip
|
x = clip_boxes(x, (h - eps, w - eps))
|
||||||
assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
|
assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
|
||||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
|
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
|
||||||
y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center
|
y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center
|
||||||
@ -740,7 +738,7 @@ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False
|
|||||||
coords[..., 1] -= pad[1] # y padding
|
coords[..., 1] -= pad[1] # y padding
|
||||||
coords[..., 0] /= gain
|
coords[..., 0] /= gain
|
||||||
coords[..., 1] /= gain
|
coords[..., 1] /= gain
|
||||||
clip_coords(coords, img0_shape)
|
coords = clip_coords(coords, img0_shape)
|
||||||
if normalize:
|
if normalize:
|
||||||
coords[..., 0] /= img0_shape[1] # width
|
coords[..., 0] /= img0_shape[1] # width
|
||||||
coords[..., 1] /= img0_shape[0] # height
|
coords[..., 1] /= img0_shape[0] # height
|
||||||
|
@ -353,7 +353,7 @@ def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False,
|
|||||||
b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
|
b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
|
||||||
b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
|
b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
|
||||||
xyxy = ops.xywh2xyxy(b).long()
|
xyxy = ops.xywh2xyxy(b).long()
|
||||||
ops.clip_boxes(xyxy, im.shape)
|
xyxy = ops.clip_boxes(xyxy, im.shape)
|
||||||
crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
|
crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
|
||||||
if save:
|
if save:
|
||||||
file.parent.mkdir(parents=True, exist_ok=True) # make directory
|
file.parent.mkdir(parents=True, exist_ok=True) # make directory
|
||||||
|
Loading…
x
Reference in New Issue
Block a user