mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-07-07 22:04:53 +08:00
Add utils.ops
and nn.modules
to tests (#4484)
This commit is contained in:
parent
1cec0185a1
commit
6da8f7f51e
@ -17,10 +17,6 @@ keywords: Ultralytics, hub functions, model export, dataset check, reset model,
|
|||||||
## ::: ultralytics.hub.logout
|
## ::: ultralytics.hub.logout
|
||||||
<br><br>
|
<br><br>
|
||||||
|
|
||||||
---
|
|
||||||
## ::: ultralytics.hub.start
|
|
||||||
<br><br>
|
|
||||||
|
|
||||||
---
|
---
|
||||||
## ::: ultralytics.hub.reset_model
|
## ::: ultralytics.hub.reset_model
|
||||||
<br><br>
|
<br><br>
|
||||||
|
@ -33,10 +33,6 @@ keywords: Ultralytics, YOLO, YOLOv3, YOLOv4, metrics, confusion matrix, detectio
|
|||||||
## ::: ultralytics.utils.metrics.ClassifyMetrics
|
## ::: ultralytics.utils.metrics.ClassifyMetrics
|
||||||
<br><br>
|
<br><br>
|
||||||
|
|
||||||
---
|
|
||||||
## ::: ultralytics.utils.metrics.box_area
|
|
||||||
<br><br>
|
|
||||||
|
|
||||||
---
|
---
|
||||||
## ::: ultralytics.utils.metrics.bbox_ioa
|
## ::: ultralytics.utils.metrics.bbox_ioa
|
||||||
<br><br>
|
<br><br>
|
||||||
|
@ -57,10 +57,6 @@ keywords: Ultralytics YOLO, Utility Operations, segment2box, make_divisible, cli
|
|||||||
## ::: ultralytics.utils.ops.xyxy2xywhn
|
## ::: ultralytics.utils.ops.xyxy2xywhn
|
||||||
<br><br>
|
<br><br>
|
||||||
|
|
||||||
---
|
|
||||||
## ::: ultralytics.utils.ops.xyn2xy
|
|
||||||
<br><br>
|
|
||||||
|
|
||||||
---
|
---
|
||||||
## ::: ultralytics.utils.ops.xywh2ltwh
|
## ::: ultralytics.utils.ops.xywh2ltwh
|
||||||
<br><br>
|
<br><br>
|
||||||
|
@ -6,6 +6,7 @@ from pathlib import Path
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from ultralytics.utils import ROOT
|
from ultralytics.utils import ROOT
|
||||||
|
from ultralytics.utils.torch_utils import init_seeds
|
||||||
|
|
||||||
TMP = (ROOT / '../tests/tmp').resolve() # temp directory for test files
|
TMP = (ROOT / '../tests/tmp').resolve() # temp directory for test files
|
||||||
|
|
||||||
@ -32,6 +33,7 @@ def pytest_sessionstart(session):
|
|||||||
"""
|
"""
|
||||||
Called after the 'Session' object has been created and before performing test collection.
|
Called after the 'Session' object has been created and before performing test collection.
|
||||||
"""
|
"""
|
||||||
|
init_seeds()
|
||||||
shutil.rmtree(TMP, ignore_errors=True) # delete any existing tests/tmp directory
|
shutil.rmtree(TMP, ignore_errors=True) # delete any existing tests/tmp directory
|
||||||
TMP.mkdir(parents=True, exist_ok=True) # create a new empty directory
|
TMP.mkdir(parents=True, exist_ok=True) # create a new empty directory
|
||||||
|
|
||||||
|
@ -128,7 +128,7 @@ def test_track_stream():
|
|||||||
|
|
||||||
def test_val():
|
def test_val():
|
||||||
model = YOLO(MODEL)
|
model = YOLO(MODEL)
|
||||||
model.val(data='coco8.yaml', imgsz=32)
|
model.val(data='coco8.yaml', imgsz=32, save_hybrid=True)
|
||||||
|
|
||||||
|
|
||||||
def test_train_scratch():
|
def test_train_scratch():
|
||||||
@ -348,9 +348,20 @@ def test_utils_downloads():
|
|||||||
|
|
||||||
|
|
||||||
def test_utils_ops():
|
def test_utils_ops():
|
||||||
from ultralytics.utils.ops import make_divisible
|
from ultralytics.utils.ops import (ltwh2xywh, ltwh2xyxy, make_divisible, xywh2ltwh, xywh2xyxy, xywhn2xyxy,
|
||||||
|
xywhr2xyxyxyxy, xyxy2ltwh, xyxy2xywh, xyxy2xywhn, xyxyxyxy2xywhr)
|
||||||
|
|
||||||
make_divisible(17, 8)
|
make_divisible(17, torch.tensor([8]))
|
||||||
|
|
||||||
|
boxes = torch.rand(10, 4) # xywh
|
||||||
|
torch.allclose(boxes, xyxy2xywh(xywh2xyxy(boxes)))
|
||||||
|
torch.allclose(boxes, xyxy2xywhn(xywhn2xyxy(boxes)))
|
||||||
|
torch.allclose(boxes, ltwh2xywh(xywh2ltwh(boxes)))
|
||||||
|
torch.allclose(boxes, xyxy2ltwh(ltwh2xyxy(boxes)))
|
||||||
|
|
||||||
|
boxes = torch.rand(10, 5) # xywhr for OBB
|
||||||
|
boxes[:, 4] = torch.randn(10) * 30
|
||||||
|
torch.allclose(boxes, xyxyxyxy2xywhr(xywhr2xyxyxyxy(boxes)), rtol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
def test_utils_files():
|
def test_utils_files():
|
||||||
@ -364,3 +375,42 @@ def test_utils_files():
|
|||||||
path.mkdir(parents=True, exist_ok=True)
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
with spaces_in_path(path) as new_path:
|
with spaces_in_path(path) as new_path:
|
||||||
print(new_path)
|
print(new_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_nn_modules_conv():
|
||||||
|
from ultralytics.nn.modules.conv import CBAM, Conv2, ConvTranspose, DWConvTranspose2d, Focus
|
||||||
|
|
||||||
|
c1, c2 = 8, 16 # input and output channels
|
||||||
|
x = torch.zeros(4, c1, 10, 10) # BCHW
|
||||||
|
|
||||||
|
# Run all modules not otherwise covered in tests
|
||||||
|
DWConvTranspose2d(c1, c2)(x)
|
||||||
|
ConvTranspose(c1, c2)(x)
|
||||||
|
Focus(c1, c2)(x)
|
||||||
|
CBAM(c1)(x)
|
||||||
|
|
||||||
|
# Fuse ops
|
||||||
|
m = Conv2(c1, c2)
|
||||||
|
m.fuse_convs()
|
||||||
|
m(x)
|
||||||
|
|
||||||
|
|
||||||
|
def test_nn_modules_block():
|
||||||
|
from ultralytics.nn.modules.block import C1, C3TR, BottleneckCSP, C3Ghost, C3x
|
||||||
|
|
||||||
|
c1, c2 = 8, 16 # input and output channels
|
||||||
|
x = torch.zeros(4, c1, 10, 10) # BCHW
|
||||||
|
|
||||||
|
# Run all modules not otherwise covered in tests
|
||||||
|
C1(c1, c2)(x)
|
||||||
|
C3x(c1, c2)(x)
|
||||||
|
C3TR(c1, c2)(x)
|
||||||
|
C3Ghost(c1, c2)(x)
|
||||||
|
BottleneckCSP(c1, c2)(x)
|
||||||
|
|
||||||
|
|
||||||
|
def test_hub():
|
||||||
|
from ultralytics.hub import export_fmts_hub, logout
|
||||||
|
|
||||||
|
export_fmts_hub()
|
||||||
|
logout()
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
__version__ = '8.0.159'
|
__version__ = '8.0.159'
|
||||||
|
|
||||||
from ultralytics.hub import start
|
|
||||||
from ultralytics.models import RTDETR, SAM, YOLO
|
from ultralytics.models import RTDETR, SAM, YOLO
|
||||||
from ultralytics.models.fastsam import FastSAM
|
from ultralytics.models.fastsam import FastSAM
|
||||||
from ultralytics.models.nas import NAS
|
from ultralytics.models.nas import NAS
|
||||||
@ -10,4 +9,4 @@ from ultralytics.utils import SETTINGS as settings
|
|||||||
from ultralytics.utils.checks import check_yolo as checks
|
from ultralytics.utils.checks import check_yolo as checks
|
||||||
from ultralytics.utils.downloads import download
|
from ultralytics.utils.downloads import download
|
||||||
|
|
||||||
__all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'FastSAM', 'RTDETR', 'checks', 'download', 'start', 'settings' # allow simpler import
|
__all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'FastSAM', 'RTDETR', 'checks', 'download', 'settings' # allow simpler import
|
||||||
|
@ -5,7 +5,7 @@ import requests
|
|||||||
from ultralytics.data.utils import HUBDatasetStats
|
from ultralytics.data.utils import HUBDatasetStats
|
||||||
from ultralytics.hub.auth import Auth
|
from ultralytics.hub.auth import Auth
|
||||||
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX
|
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX
|
||||||
from ultralytics.utils import LOGGER, SETTINGS, USER_CONFIG_DIR, yaml_save
|
from ultralytics.utils import LOGGER, SETTINGS
|
||||||
|
|
||||||
|
|
||||||
def login(api_key=''):
|
def login(api_key=''):
|
||||||
@ -37,29 +37,10 @@ def logout():
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
SETTINGS['api_key'] = ''
|
SETTINGS['api_key'] = ''
|
||||||
yaml_save(USER_CONFIG_DIR / 'settings.yaml', SETTINGS)
|
SETTINGS.save()
|
||||||
LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo hub login'.")
|
LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo hub login'.")
|
||||||
|
|
||||||
|
|
||||||
def start(key=''):
|
|
||||||
"""
|
|
||||||
Start training models with Ultralytics HUB (DEPRECATED).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key (str, optional): A string containing either the API key and model ID combination (apikey_modelid),
|
|
||||||
or the full model URL (https://hub.ultralytics.com/models/apikey_modelid).
|
|
||||||
"""
|
|
||||||
api_key, model_id = key.split('_')
|
|
||||||
LOGGER.warning(f"""
|
|
||||||
WARNING ⚠️ ultralytics.start() is deprecated after 8.0.60. Updated usage to train Ultralytics HUB models is:
|
|
||||||
|
|
||||||
from ultralytics import YOLO, hub
|
|
||||||
|
|
||||||
hub.login('{api_key}')
|
|
||||||
model = YOLO('{HUB_WEB_ROOT}/models/{model_id}')
|
|
||||||
model.train()""")
|
|
||||||
|
|
||||||
|
|
||||||
def reset_model(model_id=''):
|
def reset_model(model_id=''):
|
||||||
"""Reset a trained model to an untrained state."""
|
"""Reset a trained model to an untrained state."""
|
||||||
r = requests.post(f'{HUB_API_ROOT}/model-reset', json={'apiKey': Auth().api_key, 'modelId': model_id})
|
r = requests.post(f'{HUB_API_ROOT}/model-reset', json={'apiKey': Auth().api_key, 'modelId': model_id})
|
||||||
@ -117,7 +98,3 @@ def check_dataset(path='', task='detect'):
|
|||||||
"""
|
"""
|
||||||
HUBDatasetStats(path=path, task=task).get_json()
|
HUBDatasetStats(path=path, task=task).get_json()
|
||||||
LOGGER.info(f'Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.')
|
LOGGER.info(f'Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
start()
|
|
||||||
|
@ -73,8 +73,7 @@ class Auth:
|
|||||||
bool: True if authentication is successful, False otherwise.
|
bool: True if authentication is successful, False otherwise.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
header = self.get_auth_header()
|
if header := self.get_auth_header():
|
||||||
if header:
|
|
||||||
r = requests.post(f'{HUB_API_ROOT}/v1/auth', headers=header)
|
r = requests.post(f'{HUB_API_ROOT}/v1/auth', headers=header)
|
||||||
if not r.json().get('success', False):
|
if not r.json().get('success', False):
|
||||||
raise ConnectionError('Unable to authenticate.')
|
raise ConnectionError('Unable to authenticate.')
|
||||||
@ -117,23 +116,4 @@ class Auth:
|
|||||||
return {'authorization': f'Bearer {self.id_token}'}
|
return {'authorization': f'Bearer {self.id_token}'}
|
||||||
elif self.api_key:
|
elif self.api_key:
|
||||||
return {'x-api-key': self.api_key}
|
return {'x-api-key': self.api_key}
|
||||||
else:
|
# else returns None
|
||||||
return None
|
|
||||||
|
|
||||||
def get_state(self) -> bool:
|
|
||||||
"""
|
|
||||||
Get the authentication state.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if either id_token or API key is set, False otherwise.
|
|
||||||
"""
|
|
||||||
return self.id_token or self.api_key
|
|
||||||
|
|
||||||
def set_api_key(self, key: str):
|
|
||||||
"""
|
|
||||||
Set the API key for authentication.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key (str): The API key string.
|
|
||||||
"""
|
|
||||||
self.api_key = key
|
|
||||||
|
@ -30,11 +30,10 @@ class Sam(nn.Module):
|
|||||||
SAM predicts object masks from an image and input prompts.
|
SAM predicts object masks from an image and input prompts.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image_encoder (ImageEncoderViT): The backbone used to encode the
|
image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings that allow for
|
||||||
image into image embeddings that allow for efficient mask prediction.
|
efficient mask prediction.
|
||||||
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
|
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
|
||||||
mask_decoder (MaskDecoder): Predicts masks from the image embeddings
|
mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
|
||||||
and encoded prompts.
|
|
||||||
pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
|
pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
|
||||||
pixel_std (list(float)): Std values for normalizing pixels in the input image.
|
pixel_std (list(float)): Std values for normalizing pixels in the input image.
|
||||||
"""
|
"""
|
||||||
@ -65,34 +64,25 @@ class Sam(nn.Module):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
batched_input (list(dict)): A list over input images, each a dictionary with the following keys. A prompt
|
batched_input (list(dict)): A list over input images, each a dictionary with the following keys. A prompt
|
||||||
key can be excluded if it is not present.
|
key can be excluded if it is not present.
|
||||||
'image': The image as a torch tensor in 3xHxW format,
|
'image': The image as a torch tensor in 3xHxW format, already transformed for input to the model.
|
||||||
already transformed for input to the model.
|
'original_size': (tuple(int, int)) The original size of the image before transformation, as (H, W).
|
||||||
'original_size': (tuple(int, int)) The original size of
|
'point_coords': (torch.Tensor) Batched point prompts for this image, with shape BxNx2. Already
|
||||||
the image before transformation, as (H, W).
|
transformed to the input frame of the model.
|
||||||
'point_coords': (torch.Tensor) Batched point prompts for
|
'point_labels': (torch.Tensor) Batched labels for point prompts, with shape BxN.
|
||||||
this image, with shape BxNx2. Already transformed to the
|
'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. Already transformed to the input frame of
|
||||||
input frame of the model.
|
the model.
|
||||||
'point_labels': (torch.Tensor) Batched labels for point prompts,
|
'mask_inputs': (torch.Tensor) Batched mask inputs to the model, in the form Bx1xHxW.
|
||||||
with shape BxN.
|
|
||||||
'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
|
|
||||||
Already transformed to the input frame of the model.
|
|
||||||
'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
|
|
||||||
in the form Bx1xHxW.
|
|
||||||
multimask_output (bool): Whether the model should predict multiple disambiguating masks, or return a single
|
multimask_output (bool): Whether the model should predict multiple disambiguating masks, or return a single
|
||||||
mask.
|
mask.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(list(dict)): A list over input images, where each element is as dictionary with the following keys.
|
(list(dict)): A list over input images, where each element is as dictionary with the following keys.
|
||||||
'masks': (torch.Tensor) Batched binary mask predictions,
|
'masks': (torch.Tensor) Batched binary mask predictions, with shape BxCxHxW, where B is the number of
|
||||||
with shape BxCxHxW, where B is the number of input prompts,
|
input prompts, C is determined by multimask_output, and (H, W) is the original size of the image.
|
||||||
C is determined by multimask_output, and (H, W) is the
|
'iou_predictions': (torch.Tensor) The model's predictions of mask quality, in shape BxC.
|
||||||
original size of the image.
|
'low_res_logits': (torch.Tensor) Low resolution logits with shape BxCxHxW, where H=W=256. Can be passed
|
||||||
'iou_predictions': (torch.Tensor) The model's predictions
|
as mask input to subsequent iterations of prediction.
|
||||||
of mask quality, in shape BxC.
|
|
||||||
'low_res_logits': (torch.Tensor) Low resolution logits with
|
|
||||||
shape BxCxHxW, where H=W=256. Can be passed as mask input
|
|
||||||
to subsequent iterations of prediction.
|
|
||||||
"""
|
"""
|
||||||
input_images = torch.stack([self.preprocess(x['image']) for x in batched_input], dim=0)
|
input_images = torch.stack([self.preprocess(x['image']) for x in batched_input], dim=0)
|
||||||
image_embeddings = self.image_encoder(input_images)
|
image_embeddings = self.image_encoder(input_images)
|
||||||
@ -137,16 +127,12 @@ class Sam(nn.Module):
|
|||||||
Remove padding and upscale masks to the original image size.
|
Remove padding and upscale masks to the original image size.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
masks (torch.Tensor): Batched masks from the mask_decoder,
|
masks (torch.Tensor): Batched masks from the mask_decoder, in BxCxHxW format.
|
||||||
in BxCxHxW format.
|
input_size (tuple(int, int)): The size of the model input image, in (H, W) format. Used to remove padding.
|
||||||
input_size (tuple(int, int)): The size of the image input to the
|
original_size (tuple(int, int)): The original image size before resizing for input to the model, in (H, W).
|
||||||
model, in (H, W) format. Used to remove padding.
|
|
||||||
original_size (tuple(int, int)): The original size of the image
|
|
||||||
before resizing for input to the model, in (H, W) format.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
|
(torch.Tensor): Batched masks in BxCxHxW format, where (H, W) is given by original_size.
|
||||||
is given by original_size.
|
|
||||||
"""
|
"""
|
||||||
masks = F.interpolate(
|
masks = F.interpolate(
|
||||||
masks,
|
masks,
|
||||||
|
@ -9,7 +9,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
__all__ = ('Conv', 'LightConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv',
|
__all__ = ('Conv', 'Conv2', 'LightConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv',
|
||||||
'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'RepConv')
|
'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'RepConv')
|
||||||
|
|
||||||
|
|
||||||
@ -54,6 +54,10 @@ class Conv2(Conv):
|
|||||||
"""Apply convolution, batch normalization and activation to input tensor."""
|
"""Apply convolution, batch normalization and activation to input tensor."""
|
||||||
return self.act(self.bn(self.conv(x) + self.cv2(x)))
|
return self.act(self.bn(self.conv(x) + self.cv2(x)))
|
||||||
|
|
||||||
|
def forward_fuse(self, x):
|
||||||
|
"""Apply fused convolution, batch normalization and activation to input tensor."""
|
||||||
|
return self.act(self.bn(self.conv(x)))
|
||||||
|
|
||||||
def fuse_convs(self):
|
def fuse_convs(self):
|
||||||
"""Fuse parallel convolutions."""
|
"""Fuse parallel convolutions."""
|
||||||
w = torch.zeros_like(self.conv.weight.data)
|
w = torch.zeros_like(self.conv.weight.data)
|
||||||
@ -61,6 +65,7 @@ class Conv2(Conv):
|
|||||||
w[:, :, i[0]:i[0] + 1, i[1]:i[1] + 1] = self.cv2.weight.data.clone()
|
w[:, :, i[0]:i[0] + 1, i[1]:i[1] + 1] = self.cv2.weight.data.clone()
|
||||||
self.conv.weight.data += w
|
self.conv.weight.data += w
|
||||||
self.__delattr__('cv2')
|
self.__delattr__('cv2')
|
||||||
|
self.forward = self.forward_fuse
|
||||||
|
|
||||||
|
|
||||||
class LightConv(nn.Module):
|
class LightConv(nn.Module):
|
||||||
|
@ -6,20 +6,13 @@ import scipy.linalg
|
|||||||
|
|
||||||
class KalmanFilterXYAH:
|
class KalmanFilterXYAH:
|
||||||
"""
|
"""
|
||||||
For bytetrack
|
For bytetrack. A simple Kalman filter for tracking bounding boxes in image space.
|
||||||
A simple Kalman filter for tracking bounding boxes in image space.
|
|
||||||
|
|
||||||
The 8-dimensional state space
|
The 8-dimensional state space (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y),
|
||||||
|
aspect ratio a, height h, and their respective velocities.
|
||||||
x, y, a, h, vx, vy, va, vh
|
|
||||||
|
|
||||||
contains the bounding box center position (x, y), aspect ratio a, height h,
|
|
||||||
and their respective velocities.
|
|
||||||
|
|
||||||
Object motion follows a constant velocity model. The bounding box location
|
|
||||||
(x, y, a, h) is taken as direct observation of the state space (linear
|
|
||||||
observation model).
|
|
||||||
|
|
||||||
|
Object motion follows a constant velocity model. The bounding box location (x, y, a, h) is taken as direct
|
||||||
|
observation of the state space (linear observation model).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -32,14 +25,14 @@ class KalmanFilterXYAH:
|
|||||||
self._motion_mat[i, ndim + i] = dt
|
self._motion_mat[i, ndim + i] = dt
|
||||||
self._update_mat = np.eye(ndim, 2 * ndim)
|
self._update_mat = np.eye(ndim, 2 * ndim)
|
||||||
|
|
||||||
# Motion and observation uncertainty are chosen relative to the current
|
# Motion and observation uncertainty are chosen relative to the current state estimate. These weights control
|
||||||
# state estimate. These weights control the amount of uncertainty in
|
# the amount of uncertainty in the model. This is a bit hacky.
|
||||||
# the model. This is a bit hacky.
|
|
||||||
self._std_weight_position = 1. / 20
|
self._std_weight_position = 1. / 20
|
||||||
self._std_weight_velocity = 1. / 160
|
self._std_weight_velocity = 1. / 160
|
||||||
|
|
||||||
def initiate(self, measurement):
|
def initiate(self, measurement):
|
||||||
"""Create track from unassociated measurement.
|
"""
|
||||||
|
Create track from unassociated measurement.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@ -53,7 +46,6 @@ class KalmanFilterXYAH:
|
|||||||
Returns the mean vector (8 dimensional) and covariance matrix (8x8
|
Returns the mean vector (8 dimensional) and covariance matrix (8x8
|
||||||
dimensional) of the new track. Unobserved velocities are initialized
|
dimensional) of the new track. Unobserved velocities are initialized
|
||||||
to 0 mean.
|
to 0 mean.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
mean_pos = measurement
|
mean_pos = measurement
|
||||||
mean_vel = np.zeros_like(mean_pos)
|
mean_vel = np.zeros_like(mean_pos)
|
||||||
@ -67,23 +59,21 @@ class KalmanFilterXYAH:
|
|||||||
return mean, covariance
|
return mean, covariance
|
||||||
|
|
||||||
def predict(self, mean, covariance):
|
def predict(self, mean, covariance):
|
||||||
"""Run Kalman filter prediction step.
|
"""
|
||||||
|
Run Kalman filter prediction step.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
mean : ndarray
|
mean : ndarray
|
||||||
The 8 dimensional mean vector of the object state at the previous
|
The 8 dimensional mean vector of the object state at the previous time step.
|
||||||
time step.
|
|
||||||
covariance : ndarray
|
covariance : ndarray
|
||||||
The 8x8 dimensional covariance matrix of the object state at the
|
The 8x8 dimensional covariance matrix of the object state at the previous time step.
|
||||||
previous time step.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
(ndarray, ndarray)
|
(ndarray, ndarray)
|
||||||
Returns the mean vector and covariance matrix of the predicted
|
Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are
|
||||||
state. Unobserved velocities are initialized to 0 mean.
|
initialized to 0 mean.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
std_pos = [
|
std_pos = [
|
||||||
self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-2,
|
self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-2,
|
||||||
@ -100,7 +90,8 @@ class KalmanFilterXYAH:
|
|||||||
return mean, covariance
|
return mean, covariance
|
||||||
|
|
||||||
def project(self, mean, covariance):
|
def project(self, mean, covariance):
|
||||||
"""Project state distribution to measurement space.
|
"""
|
||||||
|
Project state distribution to measurement space.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@ -112,9 +103,7 @@ class KalmanFilterXYAH:
|
|||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
(ndarray, ndarray)
|
(ndarray, ndarray)
|
||||||
Returns the projected mean and covariance matrix of the given state
|
Returns the projected mean and covariance matrix of the given state estimate.
|
||||||
estimate.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
std = [
|
std = [
|
||||||
self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-1,
|
self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-1,
|
||||||
@ -126,20 +115,21 @@ class KalmanFilterXYAH:
|
|||||||
return mean, covariance + innovation_cov
|
return mean, covariance + innovation_cov
|
||||||
|
|
||||||
def multi_predict(self, mean, covariance):
|
def multi_predict(self, mean, covariance):
|
||||||
"""Run Kalman filter prediction step (Vectorized version).
|
"""
|
||||||
|
Run Kalman filter prediction step (Vectorized version).
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
mean : ndarray
|
mean : ndarray
|
||||||
The Nx8 dimensional mean matrix of the object states at the previous
|
The Nx8 dimensional mean matrix of the object states at the previous time step.
|
||||||
time step.
|
|
||||||
covariance : ndarray
|
covariance : ndarray
|
||||||
The Nx8x8 dimensional covariance matrix of the object states at the
|
The Nx8x8 dimensional covariance matrix of the object states at the previous time step.
|
||||||
previous time step.
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
(ndarray, ndarray)
|
(ndarray, ndarray)
|
||||||
Returns the mean vector and covariance matrix of the predicted
|
Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are
|
||||||
state. Unobserved velocities are initialized to 0 mean.
|
initialized to 0 mean.
|
||||||
"""
|
"""
|
||||||
std_pos = [
|
std_pos = [
|
||||||
self._std_weight_position * mean[:, 3], self._std_weight_position * mean[:, 3],
|
self._std_weight_position * mean[:, 3], self._std_weight_position * mean[:, 3],
|
||||||
@ -159,7 +149,8 @@ class KalmanFilterXYAH:
|
|||||||
return mean, covariance
|
return mean, covariance
|
||||||
|
|
||||||
def update(self, mean, covariance, measurement):
|
def update(self, mean, covariance, measurement):
|
||||||
"""Run Kalman filter correction step.
|
"""
|
||||||
|
Run Kalman filter correction step.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@ -168,15 +159,13 @@ class KalmanFilterXYAH:
|
|||||||
covariance : ndarray
|
covariance : ndarray
|
||||||
The state's covariance matrix (8x8 dimensional).
|
The state's covariance matrix (8x8 dimensional).
|
||||||
measurement : ndarray
|
measurement : ndarray
|
||||||
The 4 dimensional measurement vector (x, y, a, h), where (x, y)
|
The 4 dimensional measurement vector (x, y, a, h), where (x, y) is the center position, a the aspect
|
||||||
is the center position, a the aspect ratio, and h the height of the
|
ratio, and h the height of the bounding box.
|
||||||
bounding box.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
(ndarray, ndarray)
|
(ndarray, ndarray)
|
||||||
Returns the measurement-corrected state distribution.
|
Returns the measurement-corrected state distribution.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
projected_mean, projected_cov = self.project(mean, covariance)
|
projected_mean, projected_cov = self.project(mean, covariance)
|
||||||
|
|
||||||
@ -191,10 +180,11 @@ class KalmanFilterXYAH:
|
|||||||
return new_mean, new_covariance
|
return new_mean, new_covariance
|
||||||
|
|
||||||
def gating_distance(self, mean, covariance, measurements, only_position=False, metric='maha'):
|
def gating_distance(self, mean, covariance, measurements, only_position=False, metric='maha'):
|
||||||
"""Compute gating distance between state distribution and measurements.
|
"""
|
||||||
A suitable distance threshold can be obtained from `chi2inv95`. If
|
Compute gating distance between state distribution and measurements. A suitable distance threshold can be
|
||||||
`only_position` is False, the chi-square distribution has 4 degrees of
|
obtained from `chi2inv95`. If `only_position` is False, the chi-square distribution has 4 degrees of
|
||||||
freedom, otherwise 2.
|
freedom, otherwise 2.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
mean : ndarray
|
mean : ndarray
|
||||||
@ -202,18 +192,16 @@ class KalmanFilterXYAH:
|
|||||||
covariance : ndarray
|
covariance : ndarray
|
||||||
Covariance of the state distribution (8x8 dimensional).
|
Covariance of the state distribution (8x8 dimensional).
|
||||||
measurements : ndarray
|
measurements : ndarray
|
||||||
An Nx4 dimensional matrix of N measurements, each in
|
An Nx4 dimensional matrix of N measurements, each in format (x, y, a, h) where (x, y) is the bounding box
|
||||||
format (x, y, a, h) where (x, y) is the bounding box center
|
center position, a the aspect ratio, and h the height.
|
||||||
position, a the aspect ratio, and h the height.
|
|
||||||
only_position : Optional[bool]
|
only_position : Optional[bool]
|
||||||
If True, distance computation is done with respect to the bounding
|
If True, distance computation is done with respect to the bounding box center position only.
|
||||||
box center position only.
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
ndarray
|
ndarray
|
||||||
Returns an array of length N, where the i-th element contains the
|
Returns an array of length N, where the i-th element contains the squared Mahalanobis distance between
|
||||||
squared Mahalanobis distance between (mean, covariance) and
|
(mean, covariance) and `measurements[i]`.
|
||||||
`measurements[i]`.
|
|
||||||
"""
|
"""
|
||||||
mean, covariance = self.project(mean, covariance)
|
mean, covariance = self.project(mean, covariance)
|
||||||
if only_position:
|
if only_position:
|
||||||
@ -233,38 +221,29 @@ class KalmanFilterXYAH:
|
|||||||
|
|
||||||
class KalmanFilterXYWH(KalmanFilterXYAH):
|
class KalmanFilterXYWH(KalmanFilterXYAH):
|
||||||
"""
|
"""
|
||||||
For BoT-SORT
|
For BoT-SORT. A simple Kalman filter for tracking bounding boxes in image space.
|
||||||
A simple Kalman filter for tracking bounding boxes in image space.
|
|
||||||
|
|
||||||
The 8-dimensional state space
|
The 8-dimensional state space (x, y, w, h, vx, vy, vw, vh) contains the bounding box center position (x, y),
|
||||||
|
width w, height h, and their respective velocities.
|
||||||
x, y, w, h, vx, vy, vw, vh
|
|
||||||
|
|
||||||
contains the bounding box center position (x, y), width w, height h,
|
|
||||||
and their respective velocities.
|
|
||||||
|
|
||||||
Object motion follows a constant velocity model. The bounding box location
|
|
||||||
(x, y, w, h) is taken as direct observation of the state space (linear
|
|
||||||
observation model).
|
|
||||||
|
|
||||||
|
Object motion follows a constant velocity model. The bounding box location (x, y, w, h) is taken as direct
|
||||||
|
observation of the state space (linear observation model).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def initiate(self, measurement):
|
def initiate(self, measurement):
|
||||||
"""Create track from unassociated measurement.
|
"""
|
||||||
|
Create track from unassociated measurement.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
measurement : ndarray
|
measurement : ndarray
|
||||||
Bounding box coordinates (x, y, w, h) with center position (x, y),
|
Bounding box coordinates (x, y, w, h) with center position (x, y), width w, and height h.
|
||||||
width w, and height h.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
(ndarray, ndarray)
|
(ndarray, ndarray)
|
||||||
Returns the mean vector (8 dimensional) and covariance matrix (8x8
|
Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) of the new track.
|
||||||
dimensional) of the new track. Unobserved velocities are initialized
|
Unobserved velocities are initialized to 0 mean.
|
||||||
to 0 mean.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
mean_pos = measurement
|
mean_pos = measurement
|
||||||
mean_vel = np.zeros_like(mean_pos)
|
mean_vel = np.zeros_like(mean_pos)
|
||||||
@ -279,23 +258,21 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|||||||
return mean, covariance
|
return mean, covariance
|
||||||
|
|
||||||
def predict(self, mean, covariance):
|
def predict(self, mean, covariance):
|
||||||
"""Run Kalman filter prediction step.
|
"""
|
||||||
|
Run Kalman filter prediction step.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
mean : ndarray
|
mean : ndarray
|
||||||
The 8 dimensional mean vector of the object state at the previous
|
The 8 dimensional mean vector of the object state at the previous time step.
|
||||||
time step.
|
|
||||||
covariance : ndarray
|
covariance : ndarray
|
||||||
The 8x8 dimensional covariance matrix of the object state at the
|
The 8x8 dimensional covariance matrix of the object state at the previous time step.
|
||||||
previous time step.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
(ndarray, ndarray)
|
(ndarray, ndarray)
|
||||||
Returns the mean vector and covariance matrix of the predicted
|
Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are
|
||||||
state. Unobserved velocities are initialized to 0 mean.
|
initialized to 0 mean.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
std_pos = [
|
std_pos = [
|
||||||
self._std_weight_position * mean[2], self._std_weight_position * mean[3],
|
self._std_weight_position * mean[2], self._std_weight_position * mean[3],
|
||||||
@ -311,7 +288,8 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|||||||
return mean, covariance
|
return mean, covariance
|
||||||
|
|
||||||
def project(self, mean, covariance):
|
def project(self, mean, covariance):
|
||||||
"""Project state distribution to measurement space.
|
"""
|
||||||
|
Project state distribution to measurement space.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@ -323,9 +301,7 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
(ndarray, ndarray)
|
(ndarray, ndarray)
|
||||||
Returns the projected mean and covariance matrix of the given state
|
Returns the projected mean and covariance matrix of the given state estimate.
|
||||||
estimate.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
std = [
|
std = [
|
||||||
self._std_weight_position * mean[2], self._std_weight_position * mean[3],
|
self._std_weight_position * mean[2], self._std_weight_position * mean[3],
|
||||||
@ -337,20 +313,21 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|||||||
return mean, covariance + innovation_cov
|
return mean, covariance + innovation_cov
|
||||||
|
|
||||||
def multi_predict(self, mean, covariance):
|
def multi_predict(self, mean, covariance):
|
||||||
"""Run Kalman filter prediction step (Vectorized version).
|
"""
|
||||||
|
Run Kalman filter prediction step (Vectorized version).
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
mean : ndarray
|
mean : ndarray
|
||||||
The Nx8 dimensional mean matrix of the object states at the previous
|
The Nx8 dimensional mean matrix of the object states at the previous time step.
|
||||||
time step.
|
|
||||||
covariance : ndarray
|
covariance : ndarray
|
||||||
The Nx8x8 dimensional covariance matrix of the object states at the
|
The Nx8x8 dimensional covariance matrix of the object states at the previous time step.
|
||||||
previous time step.
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
(ndarray, ndarray)
|
(ndarray, ndarray)
|
||||||
Returns the mean vector and covariance matrix of the predicted
|
Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are
|
||||||
state. Unobserved velocities are initialized to 0 mean.
|
initialized to 0 mean.
|
||||||
"""
|
"""
|
||||||
std_pos = [
|
std_pos = [
|
||||||
self._std_weight_position * mean[:, 2], self._std_weight_position * mean[:, 3],
|
self._std_weight_position * mean[:, 2], self._std_weight_position * mean[:, 3],
|
||||||
@ -370,7 +347,8 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|||||||
return mean, covariance
|
return mean, covariance
|
||||||
|
|
||||||
def update(self, mean, covariance, measurement):
|
def update(self, mean, covariance, measurement):
|
||||||
"""Run Kalman filter correction step.
|
"""
|
||||||
|
Run Kalman filter correction step.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@ -379,14 +357,12 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|||||||
covariance : ndarray
|
covariance : ndarray
|
||||||
The state's covariance matrix (8x8 dimensional).
|
The state's covariance matrix (8x8 dimensional).
|
||||||
measurement : ndarray
|
measurement : ndarray
|
||||||
The 4 dimensional measurement vector (x, y, w, h), where (x, y)
|
The 4 dimensional measurement vector (x, y, w, h), where (x, y) is the center position, w the width,
|
||||||
is the center position, w the width, and h the height of the
|
and h the height of the bounding box.
|
||||||
bounding box.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
(ndarray, ndarray)
|
(ndarray, ndarray)
|
||||||
Returns the measurement-corrected state distribution.
|
Returns the measurement-corrected state distribution.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return super().update(mean, covariance, measurement)
|
return super().update(mean, covariance, measurement)
|
||||||
|
@ -212,21 +212,18 @@ def get_google_drive_file_info(link):
|
|||||||
"""
|
"""
|
||||||
file_id = link.split('/d/')[1].split('/view')[0]
|
file_id = link.split('/d/')[1].split('/view')[0]
|
||||||
drive_url = f'https://drive.google.com/uc?export=download&id={file_id}'
|
drive_url = f'https://drive.google.com/uc?export=download&id={file_id}'
|
||||||
|
filename = None
|
||||||
|
|
||||||
# Start session
|
# Start session
|
||||||
filename = None
|
|
||||||
with requests.Session() as session:
|
with requests.Session() as session:
|
||||||
response = session.get(drive_url, stream=True)
|
response = session.get(drive_url, stream=True)
|
||||||
if 'quota exceeded' in str(response.content.lower()):
|
if 'quota exceeded' in str(response.content.lower()):
|
||||||
raise ConnectionError(
|
raise ConnectionError(
|
||||||
emojis(f'❌ Google Drive file download quota exceeded. '
|
emojis(f'❌ Google Drive file download quota exceeded. '
|
||||||
f'Please try again later or download this file manually at {link}.'))
|
f'Please try again later or download this file manually at {link}.'))
|
||||||
token = None
|
for k, v in response.cookies.items():
|
||||||
for key, value in response.cookies.items():
|
if k.startswith('download_warning'):
|
||||||
if key.startswith('download_warning'):
|
drive_url += f'&confirm={v}' # v is token
|
||||||
token = value
|
|
||||||
if token:
|
|
||||||
drive_url = f'https://drive.google.com/uc?export=download&confirm={token}&id={file_id}'
|
|
||||||
cd = response.headers.get('content-disposition')
|
cd = response.headers.get('content-disposition')
|
||||||
if cd:
|
if cd:
|
||||||
filename = re.findall('filename="(.+)"', cd)[0]
|
filename = re.findall('filename="(.+)"', cd)[0]
|
||||||
|
@ -15,12 +15,6 @@ from ultralytics.utils import LOGGER, SimpleClass, TryExcept, plt_settings
|
|||||||
OKS_SIGMA = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0
|
OKS_SIGMA = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0
|
||||||
|
|
||||||
|
|
||||||
# Boxes
|
|
||||||
def box_area(box):
|
|
||||||
"""Return box area, where box shape is xyxy(4,n)."""
|
|
||||||
return (box[2] - box[0]) * (box[3] - box[1])
|
|
||||||
|
|
||||||
|
|
||||||
def bbox_ioa(box1, box2, iou=False, eps=1e-7):
|
def bbox_ioa(box1, box2, iou=False, eps=1e-7):
|
||||||
"""
|
"""
|
||||||
Calculate the intersection over box2 area given box1 and box2. Boxes are in x1y1x2y2 format.
|
Calculate the intersection over box2 area given box1 and box2. Boxes are in x1y1x2y2 format.
|
||||||
@ -869,11 +863,6 @@ class PoseMetrics(SegmentMetrics):
|
|||||||
self.pose = Metric()
|
self.pose = Metric()
|
||||||
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
||||||
|
|
||||||
def __getattr__(self, attr):
|
|
||||||
"""Raises an AttributeError if an invalid attribute is accessed."""
|
|
||||||
name = self.__class__.__name__
|
|
||||||
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
|
||||||
|
|
||||||
def process(self, tp_b, tp_p, conf, pred_cls, target_cls):
|
def process(self, tp_b, tp_p, conf, pred_cls, target_cls):
|
||||||
"""
|
"""
|
||||||
Processes the detection and pose metrics over the given set of predictions.
|
Processes the detection and pose metrics over the given set of predictions.
|
||||||
|
@ -13,8 +13,6 @@ import torchvision
|
|||||||
|
|
||||||
from ultralytics.utils import LOGGER
|
from ultralytics.utils import LOGGER
|
||||||
|
|
||||||
from .metrics import box_iou
|
|
||||||
|
|
||||||
|
|
||||||
class Profile(contextlib.ContextDecorator):
|
class Profile(contextlib.ContextDecorator):
|
||||||
"""
|
"""
|
||||||
@ -32,23 +30,17 @@ class Profile(contextlib.ContextDecorator):
|
|||||||
self.cuda = torch.cuda.is_available()
|
self.cuda = torch.cuda.is_available()
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
"""
|
"""Start timing."""
|
||||||
Start timing.
|
|
||||||
"""
|
|
||||||
self.start = self.time()
|
self.start = self.time()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, type, value, traceback): # noqa
|
def __exit__(self, type, value, traceback): # noqa
|
||||||
"""
|
"""Stop timing."""
|
||||||
Stop timing.
|
|
||||||
"""
|
|
||||||
self.dt = self.time() - self.start # delta-time
|
self.dt = self.time() - self.start # delta-time
|
||||||
self.t += self.dt # accumulate dt
|
self.t += self.dt # accumulate dt
|
||||||
|
|
||||||
def time(self):
|
def time(self):
|
||||||
"""
|
"""Get current time."""
|
||||||
Get current time.
|
|
||||||
"""
|
|
||||||
if self.cuda:
|
if self.cuda:
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
return time.time()
|
return time.time()
|
||||||
@ -56,15 +48,15 @@ class Profile(contextlib.ContextDecorator):
|
|||||||
|
|
||||||
def segment2box(segment, width=640, height=640):
|
def segment2box(segment, width=640, height=640):
|
||||||
"""
|
"""
|
||||||
Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
|
Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
segment (torch.Tensor): the segment label
|
segment (torch.Tensor): the segment label
|
||||||
width (int): the width of the image. Defaults to 640
|
width (int): the width of the image. Defaults to 640
|
||||||
height (int): The height of the image. Defaults to 640
|
height (int): The height of the image. Defaults to 640
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(np.ndarray): the minimum and maximum x and y values of the segment.
|
(np.ndarray): the minimum and maximum x and y values of the segment.
|
||||||
"""
|
"""
|
||||||
# Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
|
# Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
|
||||||
x, y = segment.T # segment xy
|
x, y = segment.T # segment xy
|
||||||
@ -80,16 +72,16 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True):
|
|||||||
(img1_shape) to the shape of a different image (img0_shape).
|
(img1_shape) to the shape of a different image (img0_shape).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
|
img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
|
||||||
boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
|
boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
|
||||||
img0_shape (tuple): the shape of the target image, in the format of (height, width).
|
img0_shape (tuple): the shape of the target image, in the format of (height, width).
|
||||||
ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
|
ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
|
||||||
calculated based on the size difference between the two images.
|
calculated based on the size difference between the two images.
|
||||||
padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
|
padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
|
||||||
rescaling.
|
rescaling.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
|
boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
|
||||||
"""
|
"""
|
||||||
if ratio_pad is None: # calculate from img0_shape
|
if ratio_pad is None: # calculate from img0_shape
|
||||||
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
|
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
|
||||||
@ -186,9 +178,7 @@ def non_max_suppression(
|
|||||||
# Settings
|
# Settings
|
||||||
# min_wh = 2 # (pixels) minimum box width and height
|
# min_wh = 2 # (pixels) minimum box width and height
|
||||||
time_limit = 0.5 + max_time_img * bs # seconds to quit after
|
time_limit = 0.5 + max_time_img * bs # seconds to quit after
|
||||||
redundant = True # require redundant detections
|
|
||||||
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
|
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
|
||||||
merge = False # use merge-NMS
|
|
||||||
|
|
||||||
prediction = prediction.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)
|
prediction = prediction.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)
|
||||||
prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy
|
prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy
|
||||||
@ -226,10 +216,6 @@ def non_max_suppression(
|
|||||||
if classes is not None:
|
if classes is not None:
|
||||||
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
|
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
|
||||||
|
|
||||||
# Apply finite constraint
|
|
||||||
# if not torch.isfinite(x).all():
|
|
||||||
# x = x[torch.isfinite(x).all(1)]
|
|
||||||
|
|
||||||
# Check shape
|
# Check shape
|
||||||
n = x.shape[0] # number of boxes
|
n = x.shape[0] # number of boxes
|
||||||
if not n: # no boxes
|
if not n: # no boxes
|
||||||
@ -242,13 +228,18 @@ def non_max_suppression(
|
|||||||
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
|
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
|
||||||
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
|
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
|
||||||
i = i[:max_det] # limit detections
|
i = i[:max_det] # limit detections
|
||||||
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
|
|
||||||
# Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
|
# # Experimental
|
||||||
iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
|
# merge = False # use merge-NMS
|
||||||
weights = iou * scores[None] # box weights
|
# if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
|
||||||
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
|
# # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
|
||||||
if redundant:
|
# from .metrics import box_iou
|
||||||
i = i[iou.sum(1) > 1] # require redundancy
|
# iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
|
||||||
|
# weights = iou * scores[None] # box weights
|
||||||
|
# x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
|
||||||
|
# redundant = True # require redundant detections
|
||||||
|
# if redundant:
|
||||||
|
# i = i[iou.sum(1) > 1] # require redundancy
|
||||||
|
|
||||||
output[xi] = x[i]
|
output[xi] = x[i]
|
||||||
if mps:
|
if mps:
|
||||||
@ -262,8 +253,7 @@ def non_max_suppression(
|
|||||||
|
|
||||||
def clip_boxes(boxes, shape):
|
def clip_boxes(boxes, shape):
|
||||||
"""
|
"""
|
||||||
It takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the
|
Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape.
|
||||||
shape
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
boxes (torch.Tensor): the bounding boxes to clip
|
boxes (torch.Tensor): the bounding boxes to clip
|
||||||
@ -303,12 +293,12 @@ def scale_image(masks, im0_shape, ratio_pad=None):
|
|||||||
Takes a mask, and resizes it to the original image size
|
Takes a mask, and resizes it to the original image size
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
masks (np.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3].
|
masks (np.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3].
|
||||||
im0_shape (tuple): the original image shape
|
im0_shape (tuple): the original image shape
|
||||||
ratio_pad (tuple): the ratio of the padding to the original image.
|
ratio_pad (tuple): the ratio of the padding to the original image.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
masks (torch.Tensor): The masks that are being returned.
|
masks (torch.Tensor): The masks that are being returned.
|
||||||
"""
|
"""
|
||||||
# Rescale coordinates (xyxy) from im1_shape to im0_shape
|
# Rescale coordinates (xyxy) from im1_shape to im0_shape
|
||||||
im1_shape = masks.shape
|
im1_shape = masks.shape
|
||||||
@ -340,6 +330,7 @@ def xyxy2xywh(x):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
|
x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
|
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
|
||||||
"""
|
"""
|
||||||
@ -359,6 +350,7 @@ def xywh2xyxy(x):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x, y, width, height) format.
|
x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x, y, width, height) format.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
|
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
|
||||||
"""
|
"""
|
||||||
@ -407,6 +399,7 @@ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
|
|||||||
h (int): The height of the image. Defaults to 640
|
h (int): The height of the image. Defaults to 640
|
||||||
clip (bool): If True, the boxes will be clipped to the image boundaries. Defaults to False
|
clip (bool): If True, the boxes will be clipped to the image boundaries. Defaults to False
|
||||||
eps (float): The minimum value of the box's width and height. Defaults to 0.0
|
eps (float): The minimum value of the box's width and height. Defaults to 0.0
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format
|
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format
|
||||||
"""
|
"""
|
||||||
@ -421,31 +414,13 @@ def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
|
|||||||
return y
|
return y
|
||||||
|
|
||||||
|
|
||||||
def xyn2xy(x, w=640, h=640, padw=0, padh=0):
|
|
||||||
"""
|
|
||||||
Convert normalized coordinates to pixel coordinates of shape (n,2)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x (np.ndarray | torch.Tensor): The input tensor of normalized bounding box coordinates
|
|
||||||
w (int): The width of the image. Defaults to 640
|
|
||||||
h (int): The height of the image. Defaults to 640
|
|
||||||
padw (int): The width of the padding. Defaults to 0
|
|
||||||
padh (int): The height of the padding. Defaults to 0
|
|
||||||
Returns:
|
|
||||||
y (np.ndarray | torch.Tensor): The x and y coordinates of the top left corner of the bounding box
|
|
||||||
"""
|
|
||||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
|
||||||
y[..., 0] = w * x[..., 0] + padw # top left x
|
|
||||||
y[..., 1] = h * x[..., 1] + padh # top left y
|
|
||||||
return y
|
|
||||||
|
|
||||||
|
|
||||||
def xywh2ltwh(x):
|
def xywh2ltwh(x):
|
||||||
"""
|
"""
|
||||||
Convert the bounding box format from [x, y, w, h] to [x1, y1, w, h], where x1, y1 are the top-left coordinates.
|
Convert the bounding box format from [x, y, w, h] to [x1, y1, w, h], where x1, y1 are the top-left coordinates.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (np.ndarray | torch.Tensor): The input tensor with the bounding box coordinates in the xywh format
|
x (np.ndarray | torch.Tensor): The input tensor with the bounding box coordinates in the xywh format
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format
|
y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format
|
||||||
"""
|
"""
|
||||||
@ -460,9 +435,10 @@ def xyxy2ltwh(x):
|
|||||||
Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right
|
Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (np.ndarray | torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format
|
x (np.ndarray | torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format.
|
y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format.
|
||||||
"""
|
"""
|
||||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||||
y[..., 2] = x[..., 2] - x[..., 0] # width
|
y[..., 2] = x[..., 2] - x[..., 0] # width
|
||||||
@ -475,7 +451,10 @@ def ltwh2xywh(x):
|
|||||||
Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center
|
Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (torch.Tensor): the input tensor
|
x (torch.Tensor): the input tensor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
y (np.ndarray | torch.Tensor): The bounding box coordinates in the xywh format.
|
||||||
"""
|
"""
|
||||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||||
y[..., 0] = x[..., 0] + x[..., 2] / 2 # center x
|
y[..., 0] = x[..., 0] + x[..., 2] / 2 # center x
|
||||||
@ -493,14 +472,8 @@ def xyxyxyxy2xywhr(corners):
|
|||||||
Returns:
|
Returns:
|
||||||
(numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5).
|
(numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5).
|
||||||
"""
|
"""
|
||||||
if isinstance(corners, torch.Tensor):
|
is_numpy = isinstance(corners, np.ndarray)
|
||||||
is_numpy = False
|
atan2, sqrt = (np.arctan2, np.sqrt) if is_numpy else (torch.atan2, torch.sqrt)
|
||||||
atan2 = torch.atan2
|
|
||||||
sqrt = torch.sqrt
|
|
||||||
else:
|
|
||||||
is_numpy = True
|
|
||||||
atan2 = np.arctan2
|
|
||||||
sqrt = np.sqrt
|
|
||||||
|
|
||||||
x1, y1, x2, y2, x3, y3, x4, y4 = corners.T
|
x1, y1, x2, y2, x3, y3, x4, y4 = corners.T
|
||||||
cx = (x1 + x3) / 2
|
cx = (x1 + x3) / 2
|
||||||
@ -527,14 +500,8 @@ def xywhr2xyxyxyxy(center):
|
|||||||
Returns:
|
Returns:
|
||||||
(numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 8).
|
(numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 8).
|
||||||
"""
|
"""
|
||||||
if isinstance(center, torch.Tensor):
|
is_numpy = isinstance(center, np.ndarray)
|
||||||
is_numpy = False
|
cos, sin = (np.cos, np.sin) if is_numpy else (torch.cos, torch.sin)
|
||||||
cos = torch.cos
|
|
||||||
sin = torch.sin
|
|
||||||
else:
|
|
||||||
is_numpy = True
|
|
||||||
cos = np.cos
|
|
||||||
sin = np.sin
|
|
||||||
|
|
||||||
cx, cy, w, h, rotation = center.T
|
cx, cy, w, h, rotation = center.T
|
||||||
rotation *= math.pi / 180.0 # degrees to radians
|
rotation *= math.pi / 180.0 # degrees to radians
|
||||||
@ -567,10 +534,10 @@ def ltwh2xyxy(x):
|
|||||||
It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
|
It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (np.ndarray | torch.Tensor): the input image
|
x (np.ndarray | torch.Tensor): the input image
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
y (np.ndarray | torch.Tensor): the xyxy coordinates of the bounding boxes.
|
y (np.ndarray | torch.Tensor): the xyxy coordinates of the bounding boxes.
|
||||||
"""
|
"""
|
||||||
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
||||||
y[..., 2] = x[..., 2] + x[..., 0] # width
|
y[..., 2] = x[..., 2] + x[..., 0] # width
|
||||||
@ -583,10 +550,10 @@ def segments2boxes(segments):
|
|||||||
It converts segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
|
It converts segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
segments (list): list of segments, each segment is a list of points, each point is a list of x, y coordinates
|
segments (list): list of segments, each segment is a list of points, each point is a list of x, y coordinates
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(np.ndarray): the xywh coordinates of the bounding boxes.
|
(np.ndarray): the xywh coordinates of the bounding boxes.
|
||||||
"""
|
"""
|
||||||
boxes = []
|
boxes = []
|
||||||
for s in segments:
|
for s in segments:
|
||||||
@ -600,11 +567,11 @@ def resample_segments(segments, n=1000):
|
|||||||
Inputs a list of segments (n,2) and returns a list of segments (n,2) up-sampled to n points each.
|
Inputs a list of segments (n,2) and returns a list of segments (n,2) up-sampled to n points each.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
segments (list): a list of (n,2) arrays, where n is the number of points in the segment.
|
segments (list): a list of (n,2) arrays, where n is the number of points in the segment.
|
||||||
n (int): number of points to resample the segment to. Defaults to 1000
|
n (int): number of points to resample the segment to. Defaults to 1000
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
segments (list): the resampled segments.
|
segments (list): the resampled segments.
|
||||||
"""
|
"""
|
||||||
for i, s in enumerate(segments):
|
for i, s in enumerate(segments):
|
||||||
s = np.concatenate((s, s[0:1, :]), axis=0)
|
s = np.concatenate((s, s[0:1, :]), axis=0)
|
||||||
@ -617,14 +584,14 @@ def resample_segments(segments, n=1000):
|
|||||||
|
|
||||||
def crop_mask(masks, boxes):
|
def crop_mask(masks, boxes):
|
||||||
"""
|
"""
|
||||||
It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box
|
It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
masks (torch.Tensor): [n, h, w] tensor of masks
|
masks (torch.Tensor): [n, h, w] tensor of masks
|
||||||
boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form
|
boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(torch.Tensor): The masks are being cropped to the bounding box.
|
(torch.Tensor): The masks are being cropped to the bounding box.
|
||||||
"""
|
"""
|
||||||
n, h, w = masks.shape
|
n, h, w = masks.shape
|
||||||
x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1)
|
x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1)
|
||||||
@ -636,17 +603,17 @@ def crop_mask(masks, boxes):
|
|||||||
|
|
||||||
def process_mask_upsample(protos, masks_in, bboxes, shape):
|
def process_mask_upsample(protos, masks_in, bboxes, shape):
|
||||||
"""
|
"""
|
||||||
It takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher
|
Takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher
|
||||||
quality but is slower.
|
quality but is slower.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
protos (torch.Tensor): [mask_dim, mask_h, mask_w]
|
protos (torch.Tensor): [mask_dim, mask_h, mask_w]
|
||||||
masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
|
masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
|
||||||
bboxes (torch.Tensor): [n, 4], n is number of masks after nms
|
bboxes (torch.Tensor): [n, 4], n is number of masks after nms
|
||||||
shape (tuple): the size of the input image (h,w)
|
shape (tuple): the size of the input image (h,w)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(torch.Tensor): The upsampled masks.
|
(torch.Tensor): The upsampled masks.
|
||||||
"""
|
"""
|
||||||
c, mh, mw = protos.shape # CHW
|
c, mh, mw = protos.shape # CHW
|
||||||
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
|
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
|
||||||
@ -692,13 +659,13 @@ def process_mask_native(protos, masks_in, bboxes, shape):
|
|||||||
It takes the output of the mask head, and crops it after upsampling to the bounding boxes.
|
It takes the output of the mask head, and crops it after upsampling to the bounding boxes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
protos (torch.Tensor): [mask_dim, mask_h, mask_w]
|
protos (torch.Tensor): [mask_dim, mask_h, mask_w]
|
||||||
masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
|
masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
|
||||||
bboxes (torch.Tensor): [n, 4], n is number of masks after nms
|
bboxes (torch.Tensor): [n, 4], n is number of masks after nms
|
||||||
shape (tuple): the size of the input image (h,w)
|
shape (tuple): the size of the input image (h,w)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
masks (torch.Tensor): The returned masks with dimensions [h, w, n]
|
masks (torch.Tensor): The returned masks with dimensions [h, w, n]
|
||||||
"""
|
"""
|
||||||
c, mh, mw = protos.shape # CHW
|
c, mh, mw = protos.shape # CHW
|
||||||
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
|
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
|
||||||
@ -733,19 +700,19 @@ def scale_masks(masks, shape, padding=True):
|
|||||||
|
|
||||||
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True):
|
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True):
|
||||||
"""
|
"""
|
||||||
Rescale segment coordinates (xyxy) from img1_shape to img0_shape
|
Rescale segment coordinates (xy) from img1_shape to img0_shape
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
img1_shape (tuple): The shape of the image that the coords are from.
|
img1_shape (tuple): The shape of the image that the coords are from.
|
||||||
coords (torch.Tensor): the coords to be scaled
|
coords (torch.Tensor): the coords to be scaled of shape n,2.
|
||||||
img0_shape (tuple): the shape of the image that the segmentation is being applied to
|
img0_shape (tuple): the shape of the image that the segmentation is being applied to.
|
||||||
ratio_pad (tuple): the ratio of the image size to the padded image size.
|
ratio_pad (tuple): the ratio of the image size to the padded image size.
|
||||||
normalize (bool): If True, the coordinates will be normalized to the range [0, 1]. Defaults to False
|
normalize (bool): If True, the coordinates will be normalized to the range [0, 1]. Defaults to False.
|
||||||
padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
|
padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular
|
||||||
rescaling.
|
rescaling.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
coords (torch.Tensor): the segmented image.
|
coords (torch.Tensor): The scaled coordinates.
|
||||||
"""
|
"""
|
||||||
if ratio_pad is None: # calculate from img0_shape
|
if ratio_pad is None: # calculate from img0_shape
|
||||||
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
|
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
|
||||||
@ -771,11 +738,11 @@ def masks2segments(masks, strategy='largest'):
|
|||||||
It takes a list of masks(n,h,w) and returns a list of segments(n,xy)
|
It takes a list of masks(n,h,w) and returns a list of segments(n,xy)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
masks (torch.Tensor): the output of the model, which is a tensor of shape (batch_size, 160, 160)
|
masks (torch.Tensor): the output of the model, which is a tensor of shape (batch_size, 160, 160)
|
||||||
strategy (str): 'concat' or 'largest'. Defaults to largest
|
strategy (str): 'concat' or 'largest'. Defaults to largest
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
segments (List): list of segment masks
|
segments (List): list of segment masks
|
||||||
"""
|
"""
|
||||||
segments = []
|
segments = []
|
||||||
for x in masks.int().cpu().numpy().astype('uint8'):
|
for x in masks.int().cpu().numpy().astype('uint8'):
|
||||||
@ -796,9 +763,9 @@ def clean_str(s):
|
|||||||
Cleans a string by replacing special characters with underscore _
|
Cleans a string by replacing special characters with underscore _
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
s (str): a string needing special characters replaced
|
s (str): a string needing special characters replaced
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(str): a string with special characters replaced by an underscore _
|
(str): a string with special characters replaced by an underscore _
|
||||||
"""
|
"""
|
||||||
return re.sub(pattern='[|@#!¡·$€%&()=?¿^*;:,¨´><+]', repl='_', string=s)
|
return re.sub(pattern='[|@#!¡·$€%&()=?¿^*;:,¨´><+]', repl='_', string=s)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user