mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
ultralytics 8.0.133
add torchvision
compatibility check (#3703)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
0821ccb618
commit
c55a98ab8e
@ -43,6 +43,11 @@ keywords: YOLO, Ultralytics, Utils, Checks, image sizing, version updates, font
|
||||
### ::: ultralytics.yolo.utils.checks.check_requirements
|
||||
<br><br>
|
||||
|
||||
## check_torchvision
|
||||
---
|
||||
### ::: ultralytics.yolo.utils.checks.check_torchvision
|
||||
<br><br>
|
||||
|
||||
## check_suffix
|
||||
---
|
||||
### ::: ultralytics.yolo.utils.checks.check_suffix
|
||||
|
@ -66,7 +66,7 @@
|
||||
"import ultralytics\n",
|
||||
"ultralytics.checks()"
|
||||
],
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
@ -102,7 +102,7 @@
|
||||
"# Run inference on an image with YOLOv8n\n",
|
||||
"!yolo predict model=yolov8n.pt source='https://ultralytics.com/images/zidane.jpg'"
|
||||
],
|
||||
"execution_count": 2,
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
@ -169,7 +169,7 @@
|
||||
"# Validate YOLOv8n on COCO128 val\n",
|
||||
"!yolo val model=yolov8n.pt data=coco128.yaml"
|
||||
],
|
||||
"execution_count": 3,
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
@ -293,7 +293,7 @@
|
||||
"# Train YOLOv8n on COCO128 for 3 epochs\n",
|
||||
"!yolo train model=yolov8n.pt data=coco128.yaml epochs=3 imgsz=640"
|
||||
],
|
||||
"execution_count": 4,
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
@ -454,21 +454,21 @@
|
||||
"- 💡 ProTip: Export to [TensorRT](https://developer.nvidia.com/tensorrt) for up to 5x GPU speedup.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"| Format | `format=` | Model |\n",
|
||||
"|----------------------------------------------------------------------------|--------------------|---------------------------|\n",
|
||||
"| [PyTorch](https://pytorch.org/) | - | `yolov8n.pt` |\n",
|
||||
"| [TorchScript](https://pytorch.org/docs/stable/jit.html) | `torchscript` | `yolov8n.torchscript` |\n",
|
||||
"| [ONNX](https://onnx.ai/) | `onnx` | `yolov8n.onnx` |\n",
|
||||
"| [OpenVINO](https://docs.openvino.ai/latest/index.html) | `openvino` | `yolov8n_openvino_model/` |\n",
|
||||
"| [TensorRT](https://developer.nvidia.com/tensorrt) | `engine` | `yolov8n.engine` |\n",
|
||||
"| [CoreML](https://github.com/apple/coremltools) | `coreml` | `yolov8n.mlmodel` |\n",
|
||||
"| [TensorFlow SavedModel](https://www.tensorflow.org/guide/saved_model) | `saved_model` | `yolov8n_saved_model/` |\n",
|
||||
"| [TensorFlow GraphDef](https://www.tensorflow.org/api_docs/python/tf/Graph) | `pb` | `yolov8n.pb` |\n",
|
||||
"| [TensorFlow Lite](https://www.tensorflow.org/lite) | `tflite` | `yolov8n.tflite` |\n",
|
||||
"| [TensorFlow Edge TPU](https://coral.ai/docs/edgetpu/models-intro/) | `edgetpu` | `yolov8n_edgetpu.tflite` |\n",
|
||||
"| [TensorFlow.js](https://www.tensorflow.org/js) | `tfjs` | `yolov8n_web_model/` |\n",
|
||||
"| [PaddlePaddle](https://github.com/PaddlePaddle) | `paddle` | `yolov8n_paddle_model/` |\n",
|
||||
"\n"
|
||||
"| Format | `format` Argument | Model |\n",
|
||||
"|----------------------------------------------------------------------------|-------------------|---------------------------|\n",
|
||||
"| [PyTorch](https://pytorch.org/) | - | `yolov8n.pt` |\n",
|
||||
"| [TorchScript](https://pytorch.org/docs/stable/jit.html) | `torchscript` | `yolov8n.torchscript` |\n",
|
||||
"| [ONNX](https://onnx.ai/) | `onnx` | `yolov8n.onnx` |\n",
|
||||
"| [OpenVINO](https://docs.openvino.ai/latest/index.html) | `openvino` | `yolov8n_openvino_model/` |\n",
|
||||
"| [TensorRT](https://developer.nvidia.com/tensorrt) | `engine` | `yolov8n.engine` |\n",
|
||||
"| [CoreML](https://github.com/apple/coremltools) | `coreml` | `yolov8n.mlmodel` |\n",
|
||||
"| [TensorFlow SavedModel](https://www.tensorflow.org/guide/saved_model) | `saved_model` | `yolov8n_saved_model/` |\n",
|
||||
"| [TensorFlow GraphDef](https://www.tensorflow.org/api_docs/python/tf/Graph) | `pb` | `yolov8n.pb` |\n",
|
||||
"| [TensorFlow Lite](https://www.tensorflow.org/lite) | `tflite` | `yolov8n.tflite` |\n",
|
||||
"| [TensorFlow Edge TPU](https://coral.ai/docs/edgetpu/models-intro/) | `edgetpu` | `yolov8n_edgetpu.tflite` |\n",
|
||||
"| [TensorFlow.js](https://www.tensorflow.org/js) | `tfjs` | `yolov8n_web_model/` |\n",
|
||||
"| [PaddlePaddle](https://github.com/PaddlePaddle) | `paddle` | `yolov8n_paddle_model/` |\n",
|
||||
"| [NCNN](https://github.com/Tencent/ncnn) | `ncnn` | `yolov8n_ncnn_model/` |\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "nPZZeNrLCQG6"
|
||||
@ -486,7 +486,7 @@
|
||||
"id": "CYIjW4igCjqD",
|
||||
"outputId": "fc41bf7a-0ea2-41a6-9ec5-dd0455af43bc"
|
||||
},
|
||||
"execution_count": 5,
|
||||
"execution_count": null,
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
@ -533,7 +533,7 @@
|
||||
"results = model.train(data='coco128.yaml', epochs=3) # train the model\n",
|
||||
"results = model.val() # evaluate model performance on the validation set\n",
|
||||
"results = model('https://ultralytics.com/images/bus.jpg') # predict on an image\n",
|
||||
"success = model.export(format='onnx') # export the model to ONNX format"
|
||||
"results = model.export(format='onnx') # export the model to ONNX format"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "bpF9-vS_DAaf"
|
||||
@ -677,9 +677,8 @@
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Git clone and run tests on updates branch\n",
|
||||
"!git clone https://github.com/ultralytics/ultralytics -b updates\n",
|
||||
"%pip install -qe ultralytics\n",
|
||||
"!pytest ultralytics/tests"
|
||||
"!git clone https://github.com/ultralytics/ultralytics -b main\n",
|
||||
"%pip install -qe ultralytics"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "uRKlwxSJdhd1"
|
||||
@ -687,6 +686,18 @@
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Run tests (Git clone only)\n",
|
||||
"!pytest ultralytics/tests"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "GtPlh7mcCGZX"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = '8.0.132'
|
||||
__version__ = '8.0.133'
|
||||
|
||||
from ultralytics.hub import start
|
||||
from ultralytics.vit.rtdetr import RTDETR
|
||||
|
@ -25,11 +25,11 @@ class HUBTrainingSession:
|
||||
model_id (str): Identifier for the YOLOv5 model being trained.
|
||||
model_url (str): URL for the model in Ultralytics HUB.
|
||||
api_url (str): API URL for the model in Ultralytics HUB.
|
||||
auth_header (Dict): Authentication header for the Ultralytics HUB API requests.
|
||||
rate_limits (Dict): Rate limits for different API calls (in seconds).
|
||||
timers (Dict): Timers for rate limiting.
|
||||
metrics_queue (Dict): Queue for the model's metrics.
|
||||
model (Dict): Model data fetched from Ultralytics HUB.
|
||||
auth_header (dict): Authentication header for the Ultralytics HUB API requests.
|
||||
rate_limits (dict): Rate limits for different API calls (in seconds).
|
||||
timers (dict): Timers for rate limiting.
|
||||
metrics_queue (dict): Queue for the model's metrics.
|
||||
model (dict): Model data fetched from Ultralytics HUB.
|
||||
alive (bool): Indicates if the heartbeat loop is active.
|
||||
"""
|
||||
|
||||
|
@ -601,7 +601,7 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
||||
|
||||
|
||||
def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||
# Parse a YOLO model.yaml dictionary into a PyTorch model
|
||||
"""Parse a YOLO model.yaml dictionary into a PyTorch model."""
|
||||
import ast
|
||||
|
||||
# Args
|
||||
|
@ -171,8 +171,8 @@ def check_cfg_mismatch(base: Dict, custom: Dict, e=None):
|
||||
If any mismatched keys are found, the function prints out similar keys from the base list and exits the program.
|
||||
|
||||
Args:
|
||||
custom (Dict): a dictionary of custom configuration options
|
||||
base (Dict): a dictionary of base configuration options
|
||||
custom (dict): a dictionary of custom configuration options
|
||||
base (dict): a dictionary of base configuration options
|
||||
"""
|
||||
custom = _handle_deprecation(custom)
|
||||
base, custom = (set(x.keys()) for x in (base, custom))
|
||||
|
@ -642,7 +642,8 @@ class CopyPaste:
|
||||
|
||||
|
||||
class Albumentations:
|
||||
# YOLOv8 Albumentations class (optional, only used if package is installed)
|
||||
"""YOLOv8 Albumentations class (optional, only used if package is installed)"""
|
||||
|
||||
def __init__(self, p=1.0):
|
||||
"""Initialize the transform object for YOLO bbox formatted params."""
|
||||
self.p = p
|
||||
@ -819,7 +820,7 @@ def classify_albumentations(
|
||||
std=(1.0, 1.0, 1.0), # IMAGENET_STD
|
||||
auto_aug=False,
|
||||
):
|
||||
# YOLOv8 classification Albumentations (optional, only used if package is installed)
|
||||
"""YOLOv8 classification Albumentations (optional, only used if package is installed)."""
|
||||
prefix = colorstr('albumentations: ')
|
||||
try:
|
||||
import albumentations as A
|
||||
@ -851,7 +852,8 @@ def classify_albumentations(
|
||||
|
||||
|
||||
class ClassifyLetterBox:
|
||||
# YOLOv8 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
|
||||
"""YOLOv8 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])"""
|
||||
|
||||
def __init__(self, size=(640, 640), auto=False, stride=32):
|
||||
"""Resizes image and crops it to center with max dimensions 'h' and 'w'."""
|
||||
super().__init__()
|
||||
@ -871,7 +873,8 @@ class ClassifyLetterBox:
|
||||
|
||||
|
||||
class CenterCrop:
|
||||
# YOLOv8 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])
|
||||
"""YOLOv8 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])"""
|
||||
|
||||
def __init__(self, size=640):
|
||||
"""Converts an image from numpy array to PyTorch tensor."""
|
||||
super().__init__()
|
||||
@ -885,7 +888,8 @@ class CenterCrop:
|
||||
|
||||
|
||||
class ToTensor:
|
||||
# YOLOv8 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
|
||||
"""YOLOv8 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])."""
|
||||
|
||||
def __init__(self, half=False):
|
||||
"""Initialize YOLOv8 ToTensor object with optional half-precision support."""
|
||||
super().__init__()
|
||||
|
@ -63,7 +63,7 @@ class _RepeatSampler:
|
||||
|
||||
|
||||
def seed_worker(worker_id): # noqa
|
||||
# Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader
|
||||
"""Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader."""
|
||||
worker_seed = torch.initial_seed() % 2 ** 32
|
||||
np.random.seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
@ -29,7 +29,8 @@ class SourceTypes:
|
||||
|
||||
|
||||
class LoadStreams:
|
||||
# YOLOv8 streamloader, i.e. `yolo predict source='rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
|
||||
"""YOLOv8 streamloader, i.e. `yolo predict source='rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`."""
|
||||
|
||||
def __init__(self, sources='file.streams', imgsz=640, vid_stride=1):
|
||||
"""Initialize instance variables and check for consistent input stream shapes."""
|
||||
torch.backends.cudnn.benchmark = True # faster for fixed-size inference
|
||||
@ -116,7 +117,8 @@ class LoadStreams:
|
||||
|
||||
|
||||
class LoadScreenshots:
|
||||
# YOLOv8 screenshot dataloader, i.e. `yolo predict source=screen`
|
||||
"""YOLOv8 screenshot dataloader, i.e. `yolo predict source=screen`."""
|
||||
|
||||
def __init__(self, source, imgsz=640):
|
||||
"""source = [screen_number left top width height] (pixels)."""
|
||||
check_requirements('mss')
|
||||
@ -158,7 +160,8 @@ class LoadScreenshots:
|
||||
|
||||
|
||||
class LoadImages:
|
||||
# YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4`
|
||||
"""YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4`."""
|
||||
|
||||
def __init__(self, path, imgsz=640, vid_stride=1):
|
||||
"""Initialize the Dataloader and raise FileNotFoundError if file not found."""
|
||||
if isinstance(path, str) and Path(path).suffix == '.txt': # *.txt file with img/vid/dir on each line
|
||||
|
@ -278,7 +278,7 @@ def check_cls_dataset(dataset: str, split=''):
|
||||
split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the following keys:
|
||||
(dict): A dictionary containing the following keys:
|
||||
- 'train' (Path): The directory path containing the training set of the dataset.
|
||||
- 'val' (Path): The directory path containing the validation set of the dataset.
|
||||
- 'test' (Path): The directory path containing the test set of the dataset.
|
||||
|
@ -213,16 +213,18 @@ class Results(SimpleClass):
|
||||
assert type(line_width) == int, '`line_width` should be of int type, i.e, line_width=3'
|
||||
|
||||
names = self.names
|
||||
annotator = Annotator(deepcopy(self.orig_img if img is None else img),
|
||||
line_width,
|
||||
font_size,
|
||||
font,
|
||||
pil,
|
||||
example=names)
|
||||
pred_boxes, show_boxes = self.boxes, boxes
|
||||
pred_masks, show_masks = self.masks, masks
|
||||
pred_probs, show_probs = self.probs, probs
|
||||
keypoints = self.keypoints
|
||||
annotator = Annotator(
|
||||
deepcopy(self.orig_img if img is None else img),
|
||||
line_width,
|
||||
font_size,
|
||||
font,
|
||||
pil or (pred_probs is not None and show_probs), # Classify tasks default to pil=True
|
||||
example=names)
|
||||
|
||||
# Plot Segment results
|
||||
if pred_masks and show_masks:
|
||||
if img_gpu is None:
|
||||
img = LetterBox(pred_masks.shape[1:])(image=annotator.result())
|
||||
@ -231,6 +233,7 @@ class Results(SimpleClass):
|
||||
idx = pred_boxes.cls if pred_boxes else range(len(pred_masks))
|
||||
annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=img_gpu)
|
||||
|
||||
# Plot Detect results
|
||||
if pred_boxes and show_boxes:
|
||||
for d in reversed(pred_boxes):
|
||||
c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
|
||||
@ -238,12 +241,15 @@ class Results(SimpleClass):
|
||||
label = (f'{name} {conf:.2f}' if conf else name) if labels else None
|
||||
annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
|
||||
|
||||
# Plot Classify results
|
||||
if pred_probs is not None and show_probs:
|
||||
text = f"{', '.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)}, "
|
||||
annotator.text((32, 32), text, txt_color=(255, 255, 255)) # TODO: allow setting colors
|
||||
text = ',\n'.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)
|
||||
x = round(self.orig_shape[0] * 0.03)
|
||||
annotator.text([x, x], text, txt_color=(255, 255, 255)) # TODO: allow setting colors
|
||||
|
||||
if keypoints is not None:
|
||||
for k in reversed(keypoints.data):
|
||||
# Plot Pose results
|
||||
if self.keypoints is not None:
|
||||
for k in reversed(self.keypoints.data):
|
||||
annotator.kpts(k, self.orig_shape, kpt_line=kpt_line)
|
||||
|
||||
return annotator.result()
|
||||
|
@ -211,6 +211,7 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
|
||||
"""
|
||||
prefix = colorstr('red', 'bold', 'requirements:')
|
||||
check_python() # check python version
|
||||
check_torchvision() # check torch-torchvision compatibility
|
||||
file = None
|
||||
if isinstance(requirements, Path): # requirements.txt file
|
||||
file = requirements.resolve()
|
||||
@ -255,6 +256,34 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
|
||||
return True
|
||||
|
||||
|
||||
def check_torchvision():
|
||||
"""
|
||||
Checks the installed versions of PyTorch and Torchvision to ensure they're compatible.
|
||||
|
||||
This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according
|
||||
to the provided compatibility table based on https://github.com/pytorch/vision#installation. The
|
||||
compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible
|
||||
Torchvision versions.
|
||||
"""
|
||||
|
||||
import torchvision
|
||||
|
||||
# Compatibility table
|
||||
compatibility_table = {'2.0': ['0.15'], '1.13': ['0.14'], '1.12': ['0.13']}
|
||||
|
||||
# Extract only the major and minor versions
|
||||
v_torch = '.'.join(torch.__version__.split('+')[0].split('.')[:2])
|
||||
v_torchvision = '.'.join(torchvision.__version__.split('+')[0].split('.')[:2])
|
||||
|
||||
if v_torch in compatibility_table:
|
||||
compatible_versions = compatibility_table[v_torch]
|
||||
if all(pkg.parse_version(v_torchvision) != pkg.parse_version(v) for v in compatible_versions):
|
||||
print(f'WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n'
|
||||
f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
|
||||
"'pip install -U torch torchvision' to update both.\n"
|
||||
'For a full compatibility table see https://github.com/pytorch/vision#installation')
|
||||
|
||||
|
||||
def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''):
|
||||
"""Check file(s) for acceptable suffix."""
|
||||
if file and suffix:
|
||||
@ -402,7 +431,7 @@ def check_amp(model):
|
||||
|
||||
|
||||
def git_describe(path=ROOT): # path must be a directory
|
||||
# Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
|
||||
"""Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe."""
|
||||
try:
|
||||
assert (Path(path) / '.git').is_dir()
|
||||
return subprocess.check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
|
||||
|
@ -91,7 +91,7 @@ def get_latest_run(search_dir='.'):
|
||||
|
||||
|
||||
def make_dirs(dir='new_dir/'):
|
||||
# Create folders
|
||||
"""Create directories."""
|
||||
dir = Path(dir)
|
||||
if dir.exists():
|
||||
shutil.rmtree(dir) # delete dir
|
||||
|
@ -55,12 +55,17 @@ class Profile(contextlib.ContextDecorator):
|
||||
return time.time()
|
||||
|
||||
|
||||
def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
|
||||
# https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
|
||||
# a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
|
||||
# b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
|
||||
# x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
|
||||
# x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
|
||||
def coco80_to_coco91_class(): #
|
||||
"""
|
||||
Converts 80-index (val2014) to 91-index (paper).
|
||||
For details see https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/.
|
||||
|
||||
Example:
|
||||
a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
|
||||
b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
|
||||
x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
|
||||
x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
|
||||
"""
|
||||
return [
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
|
||||
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
|
||||
|
@ -34,7 +34,7 @@ _torch_save = torch.save # copy to avoid recursion errors
|
||||
|
||||
|
||||
def torch_save(*args, **kwargs):
|
||||
# Use dill (if exists) to serialize the lambda functions where pickle does not do this
|
||||
"""Use dill (if exists) to serialize the lambda functions where pickle does not do this."""
|
||||
try:
|
||||
import dill as pickle
|
||||
except ImportError:
|
||||
|
@ -21,7 +21,8 @@ from .ops import clip_boxes, scale_image, xywh2xyxy, xyxy2xywh
|
||||
|
||||
|
||||
class Colors:
|
||||
# Ultralytics color palette https://ultralytics.com/
|
||||
"""Ultralytics color palette https://ultralytics.com/."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""
|
||||
hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
|
||||
@ -48,7 +49,8 @@ colors = Colors() # create instance for 'from utils.plots import colors'
|
||||
|
||||
|
||||
class Annotator:
|
||||
# YOLOv8 Annotator for train/val mosaics and jpgs and detect/hub inference annotations
|
||||
"""YOLOv8 Annotator for train/val mosaics and jpgs and detect/hub inference annotations."""
|
||||
|
||||
def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
|
||||
"""Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
|
||||
assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
|
||||
@ -204,7 +206,14 @@ class Annotator:
|
||||
self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color)
|
||||
# Using `txt_color` for background and draw fg with white color
|
||||
txt_color = (255, 255, 255)
|
||||
self.draw.text(xy, text, fill=txt_color, font=self.font)
|
||||
if '\n' in text:
|
||||
lines = text.split('\n')
|
||||
_, h = self.font.getsize(text)
|
||||
for line in lines:
|
||||
self.draw.text(xy, line, fill=txt_color, font=self.font)
|
||||
xy[1] += h
|
||||
else:
|
||||
self.draw.text(xy, text, fill=txt_color, font=self.font)
|
||||
else:
|
||||
if box_style:
|
||||
tf = max(self.lw - 1, 1) # font thickness
|
||||
@ -310,7 +319,7 @@ def plot_images(images,
|
||||
fname='images.jpg',
|
||||
names=None,
|
||||
on_plot=None):
|
||||
# Plot image grid with labels
|
||||
"""Plot image grid with labels."""
|
||||
if isinstance(images, torch.Tensor):
|
||||
images = images.cpu().float().numpy()
|
||||
if isinstance(cls, torch.Tensor):
|
||||
|
@ -232,7 +232,7 @@ def get_flops(model, imgsz=640):
|
||||
|
||||
|
||||
def get_flops_with_torch_profiler(model, imgsz=640):
|
||||
# Compute model FLOPs (thop alternative)
|
||||
"""Compute model FLOPs (thop alternative)."""
|
||||
model = de_parallel(model)
|
||||
p = next(model.parameters())
|
||||
stride = (max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32) * 2 # max stride
|
||||
|
Loading…
x
Reference in New Issue
Block a user