Mirror of https://github.com/THU-MIG/yolov10.git, synced 2025-05-23 13:34:23 +08:00

ultralytics 8.0.133 add torchvision compatibility check (#3703)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

This commit is contained in:
parent 0821ccb618
commit c55a98ab8e
@@ -43,6 +43,11 @@ keywords: YOLO, Ultralytics, Utils, Checks, image sizing, version updates, font
 ### ::: ultralytics.yolo.utils.checks.check_requirements
 <br><br>
 
+## check_torchvision
+---
+### ::: ultralytics.yolo.utils.checks.check_torchvision
+<br><br>
+
 ## check_suffix
 ---
 ### ::: ultralytics.yolo.utils.checks.check_suffix
@@ -66,7 +66,7 @@
 "import ultralytics\n",
 "ultralytics.checks()"
 ],
-"execution_count": 1,
+"execution_count": null,
 "outputs": [
 {
 "output_type": "stream",
@@ -102,7 +102,7 @@
 "# Run inference on an image with YOLOv8n\n",
 "!yolo predict model=yolov8n.pt source='https://ultralytics.com/images/zidane.jpg'"
 ],
-"execution_count": 2,
+"execution_count": null,
 "outputs": [
 {
 "output_type": "stream",
@@ -169,7 +169,7 @@
 "# Validate YOLOv8n on COCO128 val\n",
 "!yolo val model=yolov8n.pt data=coco128.yaml"
 ],
-"execution_count": 3,
+"execution_count": null,
 "outputs": [
 {
 "output_type": "stream",
@@ -293,7 +293,7 @@
 "# Train YOLOv8n on COCO128 for 3 epochs\n",
 "!yolo train model=yolov8n.pt data=coco128.yaml epochs=3 imgsz=640"
 ],
-"execution_count": 4,
+"execution_count": null,
 "outputs": [
 {
 "output_type": "stream",
@@ -454,21 +454,21 @@
 "- 💡 ProTip: Export to [TensorRT](https://developer.nvidia.com/tensorrt) for up to 5x GPU speedup.\n",
 "\n",
 "\n",
-"| Format | `format=` | Model |\n",
-"|----------------------------------------------------------------------------|--------------------|---------------------------|\n",
+"| Format | `format` Argument | Model |\n",
+"|----------------------------------------------------------------------------|-------------------|---------------------------|\n",
 "| [PyTorch](https://pytorch.org/) | - | `yolov8n.pt` |\n",
 "| [TorchScript](https://pytorch.org/docs/stable/jit.html) | `torchscript` | `yolov8n.torchscript` |\n",
 "| [ONNX](https://onnx.ai/) | `onnx` | `yolov8n.onnx` |\n",
 "| [OpenVINO](https://docs.openvino.ai/latest/index.html) | `openvino` | `yolov8n_openvino_model/` |\n",
 "| [TensorRT](https://developer.nvidia.com/tensorrt) | `engine` | `yolov8n.engine` |\n",
 "| [CoreML](https://github.com/apple/coremltools) | `coreml` | `yolov8n.mlmodel` |\n",
 "| [TensorFlow SavedModel](https://www.tensorflow.org/guide/saved_model) | `saved_model` | `yolov8n_saved_model/` |\n",
 "| [TensorFlow GraphDef](https://www.tensorflow.org/api_docs/python/tf/Graph) | `pb` | `yolov8n.pb` |\n",
 "| [TensorFlow Lite](https://www.tensorflow.org/lite) | `tflite` | `yolov8n.tflite` |\n",
 "| [TensorFlow Edge TPU](https://coral.ai/docs/edgetpu/models-intro/) | `edgetpu` | `yolov8n_edgetpu.tflite` |\n",
 "| [TensorFlow.js](https://www.tensorflow.org/js) | `tfjs` | `yolov8n_web_model/` |\n",
 "| [PaddlePaddle](https://github.com/PaddlePaddle) | `paddle` | `yolov8n_paddle_model/` |\n",
-"\n"
+"| [NCNN](https://github.com/Tencent/ncnn) | `ncnn` | `yolov8n_ncnn_model/` |\n"
 ],
 "metadata": {
 "id": "nPZZeNrLCQG6"
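Note: a minimal sketch of driving these formats from Python, assuming the `ultralytics` package and the `yolov8n.pt` weights used throughout this notebook; whether `format='ncnn'` is accepted at this exact version is an assumption based on the table row added here.

```python
# Hedged sketch: export with the `format` values from the table above.
from ultralytics import YOLO

model = YOLO('yolov8n.pt')   # load pretrained weights, as elsewhere in this notebook
model.export(format='onnx')  # -> yolov8n.onnx
model.export(format='ncnn')  # -> yolov8n_ncnn_model/ (assumes NCNN export is wired up)
```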
@@ -486,7 +486,7 @@
 "id": "CYIjW4igCjqD",
 "outputId": "fc41bf7a-0ea2-41a6-9ec5-dd0455af43bc"
 },
-"execution_count": 5,
+"execution_count": null,
 "outputs": [
 {
 "output_type": "stream",
@@ -533,7 +533,7 @@
 "results = model.train(data='coco128.yaml', epochs=3)  # train the model\n",
 "results = model.val()  # evaluate model performance on the validation set\n",
 "results = model('https://ultralytics.com/images/bus.jpg')  # predict on an image\n",
-"success = model.export(format='onnx')  # export the model to ONNX format"
+"results = model.export(format='onnx')  # export the model to ONNX format"
 ],
 "metadata": {
 "id": "bpF9-vS_DAaf"
@@ -677,9 +677,8 @@
 "cell_type": "code",
 "source": [
 "# Git clone and run tests on updates branch\n",
-"!git clone https://github.com/ultralytics/ultralytics -b updates\n",
-"%pip install -qe ultralytics\n",
-"!pytest ultralytics/tests"
+"!git clone https://github.com/ultralytics/ultralytics -b main\n",
+"%pip install -qe ultralytics"
 ],
 "metadata": {
 "id": "uRKlwxSJdhd1"
@@ -687,6 +686,18 @@
 "execution_count": null,
 "outputs": []
 },
+{
+"cell_type": "code",
+"source": [
+"# Run tests (Git clone only)\n",
+"!pytest ultralytics/tests"
+],
+"metadata": {
+"id": "GtPlh7mcCGZX"
+},
+"execution_count": null,
+"outputs": []
+},
 {
 "cell_type": "code",
 "source": [
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = '8.0.132'
+__version__ = '8.0.133'
 
 from ultralytics.hub import start
 from ultralytics.vit.rtdetr import RTDETR
@@ -25,11 +25,11 @@ class HUBTrainingSession:
         model_id (str): Identifier for the YOLOv5 model being trained.
         model_url (str): URL for the model in Ultralytics HUB.
         api_url (str): API URL for the model in Ultralytics HUB.
-        auth_header (Dict): Authentication header for the Ultralytics HUB API requests.
-        rate_limits (Dict): Rate limits for different API calls (in seconds).
-        timers (Dict): Timers for rate limiting.
-        metrics_queue (Dict): Queue for the model's metrics.
-        model (Dict): Model data fetched from Ultralytics HUB.
+        auth_header (dict): Authentication header for the Ultralytics HUB API requests.
+        rate_limits (dict): Rate limits for different API calls (in seconds).
+        timers (dict): Timers for rate limiting.
+        metrics_queue (dict): Queue for the model's metrics.
+        model (dict): Model data fetched from Ultralytics HUB.
         alive (bool): Indicates if the heartbeat loop is active.
     """
 
@@ -601,7 +601,7 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
 
 
 def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
-    # Parse a YOLO model.yaml dictionary into a PyTorch model
+    """Parse a YOLO model.yaml dictionary into a PyTorch model."""
     import ast
 
     # Args
@@ -171,8 +171,8 @@ def check_cfg_mismatch(base: Dict, custom: Dict, e=None):
     If any mismatched keys are found, the function prints out similar keys from the base list and exits the program.
 
     Args:
-        custom (Dict): a dictionary of custom configuration options
-        base (Dict): a dictionary of base configuration options
+        custom (dict): a dictionary of custom configuration options
+        base (dict): a dictionary of base configuration options
     """
     custom = _handle_deprecation(custom)
     base, custom = (set(x.keys()) for x in (base, custom))
@@ -642,7 +642,8 @@ class CopyPaste:
 
 
 class Albumentations:
-    # YOLOv8 Albumentations class (optional, only used if package is installed)
+    """YOLOv8 Albumentations class (optional, only used if package is installed)"""
+
     def __init__(self, p=1.0):
         """Initialize the transform object for YOLO bbox formatted params."""
         self.p = p
@@ -819,7 +820,7 @@ def classify_albumentations(
         std=(1.0, 1.0, 1.0),  # IMAGENET_STD
         auto_aug=False,
 ):
-    # YOLOv8 classification Albumentations (optional, only used if package is installed)
+    """YOLOv8 classification Albumentations (optional, only used if package is installed)."""
     prefix = colorstr('albumentations: ')
     try:
         import albumentations as A
@@ -851,7 +852,8 @@
 
 
 class ClassifyLetterBox:
-    # YOLOv8 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
+    """YOLOv8 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])"""
+
     def __init__(self, size=(640, 640), auto=False, stride=32):
         """Resizes image and crops it to center with max dimensions 'h' and 'w'."""
         super().__init__()
@@ -871,7 +873,8 @@ class ClassifyLetterBox:
 
 
 class CenterCrop:
-    # YOLOv8 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])
+    """YOLOv8 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])"""
+
     def __init__(self, size=640):
         """Converts an image from numpy array to PyTorch tensor."""
         super().__init__()
@@ -885,7 +888,8 @@ class CenterCrop:
 
 
 class ToTensor:
-    # YOLOv8 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
+    """YOLOv8 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])."""
+
     def __init__(self, half=False):
         """Initialize YOLOv8 ToTensor object with optional half-precision support."""
         super().__init__()
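Note: the three classify transforms above are designed to chain, per the `T.Compose` hints in their docstrings. A sketch under the assumption that they are importable from `ultralytics.yolo.data.augment`:

```python
import torchvision.transforms as T

from ultralytics.yolo.data.augment import CenterCrop, ClassifyLetterBox, ToTensor

# Classification preprocessing pipeline for a numpy HWC image
transforms = T.Compose([
    ClassifyLetterBox(size=(640, 640)),  # resize toward the target shape
    CenterCrop(size=640),                # crop to a centered 640x640 square
    ToTensor(half=False),                # numpy array -> torch tensor
])
```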
@@ -63,7 +63,7 @@ class _RepeatSampler:
 
 
 def seed_worker(worker_id):  # noqa
-    # Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader
+    """Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader."""
     worker_seed = torch.initial_seed() % 2 ** 32
     np.random.seed(worker_seed)
     random.seed(worker_seed)
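Note: a minimal sketch of how `seed_worker` plugs into a `DataLoader` for reproducible shuffling, following the PyTorch randomness notes linked above; the toy dataset is hypothetical and the module path is assumed.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from ultralytics.yolo.data.build import seed_worker  # the function above (module path assumed)

dataset = TensorDataset(torch.arange(100).float())  # hypothetical toy dataset

g = torch.Generator()
g.manual_seed(0)  # fixed base seed -> identical shuffling across runs

loader = DataLoader(dataset,
                    batch_size=16,
                    shuffle=True,
                    num_workers=2,
                    worker_init_fn=seed_worker,  # seeds numpy/random inside each worker
                    generator=g)
```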
@@ -29,7 +29,8 @@ class SourceTypes:
 
 
 class LoadStreams:
-    # YOLOv8 streamloader, i.e. `yolo predict source='rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP streams`
+    """YOLOv8 streamloader, i.e. `yolo predict source='rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP streams`."""
+
     def __init__(self, sources='file.streams', imgsz=640, vid_stride=1):
         """Initialize instance variables and check for consistent input stream shapes."""
         torch.backends.cudnn.benchmark = True  # faster for fixed-size inference
@@ -116,7 +117,8 @@ class LoadStreams:
 
 
 class LoadScreenshots:
-    # YOLOv8 screenshot dataloader, i.e. `yolo predict source=screen`
+    """YOLOv8 screenshot dataloader, i.e. `yolo predict source=screen`."""
+
     def __init__(self, source, imgsz=640):
         """source = [screen_number left top width height] (pixels)."""
         check_requirements('mss')
@@ -158,7 +160,8 @@ class LoadScreenshots:
 
 
 class LoadImages:
-    # YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4`
+    """YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4`."""
+
     def __init__(self, path, imgsz=640, vid_stride=1):
         """Initialize the Dataloader and raise FileNotFoundError if file not found."""
         if isinstance(path, str) and Path(path).suffix == '.txt':  # *.txt file with img/vid/dir on each line
@@ -278,7 +278,7 @@ def check_cls_dataset(dataset: str, split=''):
         split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''.
 
     Returns:
-        dict: A dictionary containing the following keys:
+        (dict): A dictionary containing the following keys:
             - 'train' (Path): The directory path containing the training set of the dataset.
             - 'val' (Path): The directory path containing the validation set of the dataset.
             - 'test' (Path): The directory path containing the test set of the dataset.
@@ -213,16 +213,18 @@ class Results(SimpleClass):
             assert type(line_width) == int, '`line_width` should be of int type, i.e, line_width=3'
 
         names = self.names
-        annotator = Annotator(deepcopy(self.orig_img if img is None else img),
-                              line_width,
-                              font_size,
-                              font,
-                              pil,
-                              example=names)
         pred_boxes, show_boxes = self.boxes, boxes
         pred_masks, show_masks = self.masks, masks
         pred_probs, show_probs = self.probs, probs
-        keypoints = self.keypoints
+        annotator = Annotator(
+            deepcopy(self.orig_img if img is None else img),
+            line_width,
+            font_size,
+            font,
+            pil or (pred_probs is not None and show_probs),  # Classify tasks default to pil=True
+            example=names)
+
+        # Plot Segment results
         if pred_masks and show_masks:
             if img_gpu is None:
                 img = LetterBox(pred_masks.shape[1:])(image=annotator.result())
@@ -231,6 +233,7 @@ class Results(SimpleClass):
             idx = pred_boxes.cls if pred_boxes else range(len(pred_masks))
             annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=img_gpu)
 
+        # Plot Detect results
         if pred_boxes and show_boxes:
             for d in reversed(pred_boxes):
                 c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
@@ -238,12 +241,15 @@ class Results(SimpleClass):
                 label = (f'{name} {conf:.2f}' if conf else name) if labels else None
                 annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
 
+        # Plot Classify results
         if pred_probs is not None and show_probs:
-            text = f"{', '.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)}, "
-            annotator.text((32, 32), text, txt_color=(255, 255, 255))  # TODO: allow setting colors
+            text = ',\n'.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)
+            x = round(self.orig_shape[0] * 0.03)
+            annotator.text([x, x], text, txt_color=(255, 255, 255))  # TODO: allow setting colors
 
-        if keypoints is not None:
-            for k in reversed(keypoints.data):
+        # Plot Pose results
+        if self.keypoints is not None:
+            for k in reversed(self.keypoints.data):
                 annotator.kpts(k, self.orig_shape, kpt_line=kpt_line)
 
         return annotator.result()
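Note: a quick sketch of the label text the new classify branch builds, with illustrative stand-ins for `names` and `pred_probs`; each top-5 class now lands on its own line.

```python
names = {0: 'cat', 1: 'dog', 2: 'bird'}  # illustrative class names
top5 = [1, 0, 2]                         # stand-in for pred_probs.top5
data = {0: 0.12, 1: 0.85, 2: 0.03}       # stand-in for pred_probs.data

text = ',\n'.join(f'{names[j]} {data[j]:.2f}' for j in top5)
print(text)
# dog 0.85,
# cat 0.12,
# bird 0.03
```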
@@ -211,6 +211,7 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
     """
     prefix = colorstr('red', 'bold', 'requirements:')
     check_python()  # check python version
+    check_torchvision()  # check torch-torchvision compatibility
     file = None
     if isinstance(requirements, Path):  # requirements.txt file
         file = requirements.resolve()
@@ -255,6 +256,34 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
     return True
 
 
+def check_torchvision():
+    """
+    Checks the installed versions of PyTorch and Torchvision to ensure they're compatible.
+
+    This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according
+    to the provided compatibility table based on https://github.com/pytorch/vision#installation. The
+    compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
+    Torchvision versions.
+    """
+
+    import torchvision
+
+    # Compatibility table
+    compatibility_table = {'2.0': ['0.15'], '1.13': ['0.14'], '1.12': ['0.13']}
+
+    # Extract only the major and minor versions
+    v_torch = '.'.join(torch.__version__.split('+')[0].split('.')[:2])
+    v_torchvision = '.'.join(torchvision.__version__.split('+')[0].split('.')[:2])
+
+    if v_torch in compatibility_table:
+        compatible_versions = compatibility_table[v_torch]
+        if all(pkg.parse_version(v_torchvision) != pkg.parse_version(v) for v in compatible_versions):
+            print(f'WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n'
+                  f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
+                  "'pip install -U torch torchvision' to update both.\n"
+                  'For a full compatibility table see https://github.com/pytorch/vision#installation')
+
+
 def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''):
     """Check file(s) for acceptable suffix."""
     if file and suffix:
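Note: a standalone sketch of the version logic above, with hard-coded strings in place of the live `torch.__version__`/`torchvision.__version__`. The new function only imports `torchvision` locally because `torch` and `pkg` (per the `pkg.parse_version` call) are evidently already module-level imports in checks.py.

```python
import pkg_resources as pkg

compatibility_table = {'2.0': ['0.15'], '1.13': ['0.14'], '1.12': ['0.13']}

# Keep only major.minor, dropping any local build tag such as '+cu118'
v_torch = '.'.join('2.0.1+cu118'.split('+')[0].split('.')[:2])   # -> '2.0'
v_torchvision = '.'.join('0.14.1'.split('+')[0].split('.')[:2])  # -> '0.14'

if v_torch in compatibility_table:
    compatible = compatibility_table[v_torch]
    mismatch = all(pkg.parse_version(v_torchvision) != pkg.parse_version(v) for v in compatible)
    print(mismatch)  # True: 0.14 is not paired with torch 2.0, so the WARNING above would fire
```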
@@ -402,7 +431,7 @@ def check_amp(model):
 
 
 def git_describe(path=ROOT):  # path must be a directory
-    # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
+    """Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe."""
     try:
         assert (Path(path) / '.git').is_dir()
         return subprocess.check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
@@ -91,7 +91,7 @@ def get_latest_run(search_dir='.'):
 
 
 def make_dirs(dir='new_dir/'):
-    # Create folders
+    """Create directories."""
     dir = Path(dir)
     if dir.exists():
         shutil.rmtree(dir)  # delete dir
@@ -55,12 +55,17 @@ class Profile(contextlib.ContextDecorator):
         return time.time()
 
 
-def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
-    # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
-    # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
-    # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
-    # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
-    # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
+def coco80_to_coco91_class():  #
+    """
+    Converts 80-index (val2014) to 91-index (paper).
+    For details see https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/.
+
+    Example:
+        a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
+        b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
+        x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
+        x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
+    """
     return [
         1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
         35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
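Note: a usage sketch for the conversion documented above; the returned list is indexed by 80-class id, and the module path in the import is assumed.

```python
from ultralytics.yolo.utils.ops import coco80_to_coco91_class  # module path assumed

coco91 = coco80_to_coco91_class()
print(coco91[0])   # 1  -> class 0 keeps paper id 1
print(coco91[11])  # 13 -> gaps open up where paper ids (12, 26, 29, ...) are unused
```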
@@ -34,7 +34,7 @@ _torch_save = torch.save  # copy to avoid recursion errors
 
 
 def torch_save(*args, **kwargs):
-    # Use dill (if exists) to serialize the lambda functions where pickle does not do this
+    """Use dill (if exists) to serialize the lambda functions where pickle does not do this."""
    try:
         import dill as pickle
     except ImportError:
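Note: a sketch of why the dill fallback matters; it runs as-is only when `dill` is installed, since plain `pickle` refuses lambdas.

```python
try:
    import dill as pickle  # dill can serialize lambdas
except ImportError:
    import pickle          # plain pickle raises PicklingError on lambdas

square = lambda x: x * x
payload = pickle.dumps(square)   # succeeds with dill installed
print(pickle.loads(payload)(6))  # 36
```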
@@ -21,7 +21,8 @@ from .ops import clip_boxes, scale_image, xywh2xyxy, xyxy2xywh
 
 
 class Colors:
-    # Ultralytics color palette https://ultralytics.com/
+    """Ultralytics color palette https://ultralytics.com/."""
+
     def __init__(self):
         """Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""
         hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
@@ -48,7 +49,8 @@ colors = Colors()  # create instance for 'from utils.plots import colors'
 
 
 class Annotator:
-    # YOLOv8 Annotator for train/val mosaics and jpgs and detect/hub inference annotations
+    """YOLOv8 Annotator for train/val mosaics and jpgs and detect/hub inference annotations."""
+
     def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
         """Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
         assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
@@ -204,7 +206,14 @@ class Annotator:
                 self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color)
                 # Using `txt_color` for background and draw fg with white color
                 txt_color = (255, 255, 255)
-            self.draw.text(xy, text, fill=txt_color, font=self.font)
+            if '\n' in text:
+                lines = text.split('\n')
+                _, h = self.font.getsize(text)
+                for line in lines:
+                    self.draw.text(xy, line, fill=txt_color, font=self.font)
+                    xy[1] += h
+            else:
+                self.draw.text(xy, text, fill=txt_color, font=self.font)
         else:
             if box_style:
                 tf = max(self.lw - 1, 1)  # font thickness
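Note: a self-contained sketch of the new multi-line branch using Pillow directly (Pillow < 10, where `ImageFont.getsize` still exists); the text is illustrative.

```python
from PIL import Image, ImageDraw, ImageFont

im = Image.new('RGB', (320, 120))
draw = ImageDraw.Draw(im)
font = ImageFont.load_default()

xy, text = [10, 10], 'dog 0.85,\ncat 0.12,\nbird 0.03'
if '\n' in text:
    lines = text.split('\n')
    _, h = font.getsize(text)  # line height from font metrics
    for line in lines:
        draw.text(tuple(xy), line, fill=(255, 255, 255), font=font)
        xy[1] += h  # advance one line height per drawn line
else:
    draw.text(tuple(xy), text, fill=(255, 255, 255), font=font)
```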
@@ -310,7 +319,7 @@ def plot_images(images,
                 fname='images.jpg',
                 names=None,
                 on_plot=None):
-    # Plot image grid with labels
+    """Plot image grid with labels."""
     if isinstance(images, torch.Tensor):
         images = images.cpu().float().numpy()
     if isinstance(cls, torch.Tensor):
@@ -232,7 +232,7 @@ def get_flops(model, imgsz=640):
 
 
 def get_flops_with_torch_profiler(model, imgsz=640):
-    # Compute model FLOPs (thop alternative)
+    """Compute model FLOPs (thop alternative)."""
     model = de_parallel(model)
     p = next(model.parameters())
     stride = (max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32) * 2  # max stride