mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Fix conversion ops using clone
and copy
(#4438)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
162e4035eb
commit
8f79ce45c1
@ -13,7 +13,7 @@ import requests
|
|||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from ultralytics.utils import LOGGER, checks, clean_url, emojis, is_online, url2file
|
from ultralytics.utils import LOGGER, TQDM_BAR_FORMAT, checks, clean_url, emojis, is_online, url2file
|
||||||
|
|
||||||
GITHUB_ASSET_NAMES = [f'yolov8{k}{suffix}.pt' for k in 'nsmlx' for suffix in ('', '6', '-cls', '-seg', '-pose')] + \
|
GITHUB_ASSET_NAMES = [f'yolov8{k}{suffix}.pt' for k in 'nsmlx' for suffix in ('', '6', '-cls', '-seg', '-pose')] + \
|
||||||
[f'yolov5{k}u.pt' for k in 'nsmlx'] + \
|
[f'yolov5{k}u.pt' for k in 'nsmlx'] + \
|
||||||
@ -287,7 +287,6 @@ def safe_download(url,
|
|||||||
if method == 'torch':
|
if method == 'torch':
|
||||||
torch.hub.download_url_to_file(url, f, progress=progress)
|
torch.hub.download_url_to_file(url, f, progress=progress)
|
||||||
else:
|
else:
|
||||||
from ultralytics.utils import TQDM_BAR_FORMAT
|
|
||||||
with request.urlopen(url) as response, tqdm(total=int(response.getheader('Content-Length', 0)),
|
with request.urlopen(url) as response, tqdm(total=int(response.getheader('Content-Length', 0)),
|
||||||
desc=desc,
|
desc=desc,
|
||||||
disable=not progress,
|
disable=not progress,
|
||||||
|
@ -24,15 +24,15 @@ to_2tuple = _ntuple(2)
|
|||||||
to_4tuple = _ntuple(4)
|
to_4tuple = _ntuple(4)
|
||||||
|
|
||||||
# `xyxy` means left top and right bottom
|
# `xyxy` means left top and right bottom
|
||||||
# `xywh` means center x, center y and width, height(yolo format)
|
# `xywh` means center x, center y and width, height(YOLO format)
|
||||||
# `ltwh` means left top and width, height(coco format)
|
# `ltwh` means left top and width, height(COCO format)
|
||||||
_formats = ['xyxy', 'xywh', 'ltwh']
|
_formats = ['xyxy', 'xywh', 'ltwh']
|
||||||
|
|
||||||
__all__ = 'Bboxes', # tuple or list
|
__all__ = 'Bboxes', # tuple or list
|
||||||
|
|
||||||
|
|
||||||
class Bboxes:
|
class Bboxes:
|
||||||
"""Now only numpy is supported."""
|
"""Bounding Boxes class. Only numpy variables are supported."""
|
||||||
|
|
||||||
def __init__(self, bboxes, format='xyxy') -> None:
|
def __init__(self, bboxes, format='xyxy') -> None:
|
||||||
assert format in _formats, f'Invalid bounding box format: {format}, format must be one of {_formats}'
|
assert format in _formats, f'Invalid bounding box format: {format}, format must be one of {_formats}'
|
||||||
@ -43,40 +43,18 @@ class Bboxes:
|
|||||||
self.format = format
|
self.format = format
|
||||||
# self.normalized = normalized
|
# self.normalized = normalized
|
||||||
|
|
||||||
# def convert(self, format):
|
|
||||||
# assert format in _formats
|
|
||||||
# if self.format == format:
|
|
||||||
# bboxes = self.bboxes
|
|
||||||
# elif self.format == "xyxy":
|
|
||||||
# if format == "xywh":
|
|
||||||
# bboxes = xyxy2xywh(self.bboxes)
|
|
||||||
# else:
|
|
||||||
# bboxes = xyxy2ltwh(self.bboxes)
|
|
||||||
# elif self.format == "xywh":
|
|
||||||
# if format == "xyxy":
|
|
||||||
# bboxes = xywh2xyxy(self.bboxes)
|
|
||||||
# else:
|
|
||||||
# bboxes = xywh2ltwh(self.bboxes)
|
|
||||||
# else:
|
|
||||||
# if format == "xyxy":
|
|
||||||
# bboxes = ltwh2xyxy(self.bboxes)
|
|
||||||
# else:
|
|
||||||
# bboxes = ltwh2xywh(self.bboxes)
|
|
||||||
#
|
|
||||||
# return Bboxes(bboxes, format)
|
|
||||||
|
|
||||||
def convert(self, format):
|
def convert(self, format):
|
||||||
"""Converts bounding box format from one type to another."""
|
"""Converts bounding box format from one type to another."""
|
||||||
assert format in _formats, f'Invalid bounding box format: {format}, format must be one of {_formats}'
|
assert format in _formats, f'Invalid bounding box format: {format}, format must be one of {_formats}'
|
||||||
if self.format == format:
|
if self.format == format:
|
||||||
return
|
return
|
||||||
elif self.format == 'xyxy':
|
elif self.format == 'xyxy':
|
||||||
bboxes = xyxy2xywh(self.bboxes) if format == 'xywh' else xyxy2ltwh(self.bboxes)
|
func = xyxy2xywh if format == 'xywh' else xyxy2ltwh
|
||||||
elif self.format == 'xywh':
|
elif self.format == 'xywh':
|
||||||
bboxes = xywh2xyxy(self.bboxes) if format == 'xyxy' else xywh2ltwh(self.bboxes)
|
func = xywh2xyxy if format == 'xyxy' else xywh2ltwh
|
||||||
else:
|
else:
|
||||||
bboxes = ltwh2xyxy(self.bboxes) if format == 'xyxy' else ltwh2xywh(self.bboxes)
|
func = ltwh2xyxy if format == 'xyxy' else ltwh2xywh
|
||||||
self.bboxes = bboxes
|
self.bboxes = func(self.bboxes)
|
||||||
self.format = format
|
self.format = format
|
||||||
|
|
||||||
def areas(self):
|
def areas(self):
|
||||||
|
@ -344,7 +344,8 @@ def xyxy2xywh(x):
|
|||||||
Returns:
|
Returns:
|
||||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
|
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
|
||||||
"""
|
"""
|
||||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
|
||||||
|
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
|
||||||
y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center
|
y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center
|
||||||
y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center
|
y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center
|
||||||
y[..., 2] = x[..., 2] - x[..., 0] # width
|
y[..., 2] = x[..., 2] - x[..., 0] # width
|
||||||
@ -362,7 +363,8 @@ def xywh2xyxy(x):
|
|||||||
Returns:
|
Returns:
|
||||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
|
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
|
||||||
"""
|
"""
|
||||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
|
||||||
|
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
|
||||||
dw = x[..., 2] / 2 # half-width
|
dw = x[..., 2] / 2 # half-width
|
||||||
dh = x[..., 3] / 2 # half-height
|
dh = x[..., 3] / 2 # half-height
|
||||||
y[..., 0] = x[..., 0] - dw # top left x
|
y[..., 0] = x[..., 0] - dw # top left x
|
||||||
@ -386,7 +388,8 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
|
|||||||
y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
|
y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
|
||||||
x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
|
x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
|
||||||
"""
|
"""
|
||||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
|
||||||
|
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
|
||||||
y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x
|
y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x
|
||||||
y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y
|
y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y
|
||||||
y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw # bottom right x
|
y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw # bottom right x
|
||||||
@ -410,7 +413,8 @@ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
|
|||||||
"""
|
"""
|
||||||
if clip:
|
if clip:
|
||||||
clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip
|
clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip
|
||||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
assert x.shape[-1] == 4, f'input shape last dimension expected 4 but input shape is {x.shape}'
|
||||||
|
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x) # faster than clone/copy
|
||||||
y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center
|
y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center
|
||||||
y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center
|
y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center
|
||||||
y[..., 2] = (x[..., 2] - x[..., 0]) / w # width
|
y[..., 2] = (x[..., 2] - x[..., 0]) / w # width
|
||||||
@ -431,7 +435,7 @@ def xyn2xy(x, w=640, h=640, padw=0, padh=0):
|
|||||||
Returns:
|
Returns:
|
||||||
y (np.ndarray | torch.Tensor): The x and y coordinates of the top left corner of the bounding box
|
y (np.ndarray | torch.Tensor): The x and y coordinates of the top left corner of the bounding box
|
||||||
"""
|
"""
|
||||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||||
y[..., 0] = w * x[..., 0] + padw # top left x
|
y[..., 0] = w * x[..., 0] + padw # top left x
|
||||||
y[..., 1] = h * x[..., 1] + padh # top left y
|
y[..., 1] = h * x[..., 1] + padh # top left y
|
||||||
return y
|
return y
|
||||||
@ -446,9 +450,9 @@ def xywh2ltwh(x):
|
|||||||
Returns:
|
Returns:
|
||||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format
|
y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format
|
||||||
"""
|
"""
|
||||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||||
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
|
y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
|
||||||
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
|
y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y
|
||||||
return y
|
return y
|
||||||
|
|
||||||
|
|
||||||
@ -461,9 +465,9 @@ def xyxy2ltwh(x):
|
|||||||
Returns:
|
Returns:
|
||||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format.
|
y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format.
|
||||||
"""
|
"""
|
||||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||||
y[:, 2] = x[:, 2] - x[:, 0] # width
|
y[..., 2] = x[..., 2] - x[..., 0] # width
|
||||||
y[:, 3] = x[:, 3] - x[:, 1] # height
|
y[..., 3] = x[..., 3] - x[..., 1] # height
|
||||||
return y
|
return y
|
||||||
|
|
||||||
|
|
||||||
@ -474,9 +478,9 @@ def ltwh2xywh(x):
|
|||||||
Args:
|
Args:
|
||||||
x (torch.Tensor): the input tensor
|
x (torch.Tensor): the input tensor
|
||||||
"""
|
"""
|
||||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||||
y[:, 0] = x[:, 0] + x[:, 2] / 2 # center x
|
y[..., 0] = x[..., 0] + x[..., 2] / 2 # center x
|
||||||
y[:, 1] = x[:, 1] + x[:, 3] / 2 # center y
|
y[..., 1] = x[..., 1] + x[..., 3] / 2 # center y
|
||||||
return y
|
return y
|
||||||
|
|
||||||
|
|
||||||
@ -569,9 +573,9 @@ def ltwh2xyxy(x):
|
|||||||
Returns:
|
Returns:
|
||||||
y (np.ndarray | torch.Tensor): the xyxy coordinates of the bounding boxes.
|
y (np.ndarray | torch.Tensor): the xyxy coordinates of the bounding boxes.
|
||||||
"""
|
"""
|
||||||
y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)
|
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||||
y[:, 2] = x[:, 2] + x[:, 0] # width
|
y[..., 2] = x[..., 2] + x[..., 0] # width
|
||||||
y[:, 3] = x[:, 3] + x[:, 1] # height
|
y[..., 3] = x[..., 3] + x[..., 1] # height
|
||||||
return y
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user