mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-25 15:14:22 +08:00
ultralytics 8.0.62
HUB Syntax updates and fixes (#1795)
Co-authored-by: Danny Kim <imbird0312@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: MagicCodess <32194768+MagicCodess@users.noreply.github.com> Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: Amjad Alsharafi <26300843+Amjad50@users.noreply.github.com>
This commit is contained in:
parent
4198570a4b
commit
37274c9845
30
.github/workflows/ci.yaml
vendored
30
.github/workflows/ci.yaml
vendored
@ -73,6 +73,36 @@ jobs:
|
|||||||
hub.login(key)
|
hub.login(key)
|
||||||
model = YOLO('https://hub.ultralytics.com/models/' + model_id)
|
model = YOLO('https://hub.ultralytics.com/models/' + model_id)
|
||||||
model.train()
|
model.train()
|
||||||
|
- name: Test HUB training (Python Usage 3)
|
||||||
|
shell: python
|
||||||
|
env:
|
||||||
|
APIKEY: ${{ secrets.ULTRALYTICS_HUB_APIKEY }}
|
||||||
|
run: |
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from ultralytics import YOLO, hub
|
||||||
|
from ultralytics.yolo.utils import USER_CONFIG_DIR
|
||||||
|
Path(USER_CONFIG_DIR / 'settings.yaml').unlink()
|
||||||
|
key = os.environ['APIKEY']
|
||||||
|
hub.reset_model(key)
|
||||||
|
model = YOLO(key)
|
||||||
|
model.train()
|
||||||
|
- name: Test HUB training (Python Usage 4)
|
||||||
|
shell: python
|
||||||
|
env:
|
||||||
|
APIKEY: ${{ secrets.ULTRALYTICS_HUB_APIKEY }}
|
||||||
|
run: |
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from ultralytics import YOLO, hub
|
||||||
|
from ultralytics.yolo.utils import USER_CONFIG_DIR
|
||||||
|
Path(USER_CONFIG_DIR / 'settings.yaml').unlink()
|
||||||
|
key = os.environ['APIKEY']
|
||||||
|
hub.reset_model(key)
|
||||||
|
key, model_id = key.split('_')
|
||||||
|
hub.login(key)
|
||||||
|
model = YOLO(model_id)
|
||||||
|
model.train()
|
||||||
|
|
||||||
Benchmarks:
|
Benchmarks:
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
|
@ -56,7 +56,7 @@ repos:
|
|||||||
name: PEP8
|
name: PEP8
|
||||||
|
|
||||||
- repo: https://github.com/codespell-project/codespell
|
- repo: https://github.com/codespell-project/codespell
|
||||||
rev: v2.2.2
|
rev: v2.2.4
|
||||||
hooks:
|
hooks:
|
||||||
- id: codespell
|
- id: codespell
|
||||||
args:
|
args:
|
||||||
|
@ -216,10 +216,20 @@ masks, classification logits, etc.) found in the results object
|
|||||||
res_plotted = res[0].plot()
|
res_plotted = res[0].plot()
|
||||||
cv2.imshow("result", res_plotted)
|
cv2.imshow("result", res_plotted)
|
||||||
```
|
```
|
||||||
|
| Argument | Description |
|
||||||
|
| ----------- | ------------- |
|
||||||
|
| `conf (bool)` | Whether to plot the detection confidence score. |
|
||||||
|
| `line_width (float, optional)` | The line width of the bounding boxes. If None, it is scaled to the image size. |
|
||||||
|
| `font_size (float, optional)` | The font size of the text. If None, it is scaled to the image size. |
|
||||||
|
| `font (str)` | The font to use for the text. |
|
||||||
|
| `pil (bool)` | Whether to return the image as a PIL Image. |
|
||||||
|
| `example (str)` | An example string to display. Useful for indicating the expected format of the output. |
|
||||||
|
| `img (numpy.ndarray)` | Plot to another image. if not, plot to original image. |
|
||||||
|
| `labels (bool)` | Whether to plot the label of bounding boxes. |
|
||||||
|
| `boxes (bool)` | Whether to plot the bounding boxes. |
|
||||||
|
| `masks (bool)` | Whether to plot the masks. |
|
||||||
|
| `probs (bool)` | Whether to plot classification probability. |
|
||||||
|
|
||||||
- `show_conf (bool)`: Show confidence
|
|
||||||
- `line_width (Float)`: The line width of boxes. Automatically scaled to img size if not provided
|
|
||||||
- `font_size (Float)`: The font size of . Automatically scaled to img size if not provided
|
|
||||||
|
|
||||||
## Streaming Source `for`-loop
|
## Streaming Source `for`-loop
|
||||||
|
|
||||||
|
@ -136,8 +136,8 @@ The prediction settings for YOLO models encompass a range of hyperparameters and
|
|||||||
| `save_txt` | `False` | save results as .txt file |
|
| `save_txt` | `False` | save results as .txt file |
|
||||||
| `save_conf` | `False` | save results with confidence scores |
|
| `save_conf` | `False` | save results with confidence scores |
|
||||||
| `save_crop` | `False` | save cropped images with results |
|
| `save_crop` | `False` | save cropped images with results |
|
||||||
| `hide_labels` | `False` | hide labels |
|
| `show_labels` | `True` | show object labels in plots |
|
||||||
| `hide_conf` | `False` | hide confidence scores |
|
| `show_conf` | `True` | show object confidence scores in plots |
|
||||||
| `max_det` | `300` | maximum number of detections per image |
|
| `max_det` | `300` | maximum number of detections per image |
|
||||||
| `vid_stride` | `False` | video frame-rate stride |
|
| `vid_stride` | `False` | video frame-rate stride |
|
||||||
| `line_thickness` | `3` | bounding box thickness (pixels) |
|
| `line_thickness` | `3` | bounding box thickness (pixels) |
|
||||||
|
@ -1,22 +1,24 @@
|
|||||||
This is a list of real-world applications and walkthroughs. These can be folders of either python files or notebooks .
|
## Ultralytics YOLOv8 Example Applications
|
||||||
|
|
||||||
## Ultralytics YOLO example applications
|
This repository features a collection of real-world applications and walkthroughs, provided as either Python files or notebooks. Explore the examples below to see how YOLOv8 can be integrated into various applications.
|
||||||
|
|
||||||
|
### Ultralytics YOLO Example Applications
|
||||||
|
|
||||||
| Title | Format | Contributor |
|
| Title | Format | Contributor |
|
||||||
| ------------------------------------------------------------------------ | ------------------ | --------------------------------------------------- |
|
| ------------------------------------------------------------------------ | ------------------ | --------------------------------------------------- |
|
||||||
| [YOLO ONNX detection Inference with C++](./YOLOv8-CPP-Inference) | C++/ONNX | [Justas Bartnykas](https://github.com/JustasBart) |
|
| [YOLO ONNX Detection Inference with C++](./YOLOv8-CPP-Inference) | C++/ONNX | [Justas Bartnykas](https://github.com/JustasBart) |
|
||||||
| [YOLO OpenCV ONNX detection Python](./YOLOv8-OpenCV-ONNX-Python) | OpenCV/Python/ONNX | [Farid Inawan](https://github.com/frdteknikelektro) |
|
| [YOLO OpenCV ONNX Detection Python](./YOLOv8-OpenCV-ONNX-Python) | OpenCV/Python/ONNX | [Farid Inawan](https://github.com/frdteknikelektro) |
|
||||||
| [YOLO .Net ONNX detection C#](https://www.nuget.org/packages/Yolov8.Net) | C# .Net | [Samuel Stainback](https://github.com/sstainba) |
|
| [YOLO .Net ONNX Detection C#](https://www.nuget.org/packages/Yolov8.Net) | C# .Net | [Samuel Stainback](https://github.com/sstainba) |
|
||||||
|
|
||||||
## How can you contribute ?
|
### How to Contribute
|
||||||
|
|
||||||
We're looking for examples, applications and guides from the community. Here's how you can contribute:
|
We welcome contributions from the community in the form of examples, applications, and guides. To contribute, please follow these steps:
|
||||||
|
|
||||||
- Make a PR with `[Example]` prefix in title after adding your project folder in the examples/ folder of the repository
|
1. Create a pull request (PR) with the `[Example]` prefix in the title, adding your project folder to the `examples/` directory in the repository.
|
||||||
- The project should satisfy these conditions:
|
1. Ensure that your project meets the following criteria:
|
||||||
- It should use ultralytics framework
|
- Utilizes the `ultralytics` package.
|
||||||
- It have a README.md with instructions to run the project
|
- Includes a `README.md` file with instructions on how to run the project.
|
||||||
- It should avoid adding large assets or dependencies unless absolutely needed
|
- Avoids adding large assets or dependencies unless absolutely necessary.
|
||||||
- The contributor is expected to help out in issues related to their examples
|
- The contributor is expected to provide support for issues related to their examples.
|
||||||
|
|
||||||
If you're unsure about any of these requirements, make a PR and we'll happy to guide you
|
If you have any questions or concerns about these requirements, please submit a PR, and we will be more than happy to guide you.
|
||||||
|
@ -1,17 +1,20 @@
|
|||||||
# yolov8/yolov5 Inference C++
|
# YOLOv8/YOLOv5 Inference C++
|
||||||
|
|
||||||
Usage:
|
This example demonstrates how to perform inference using YOLOv8 and YOLOv5 models in C++ with OpenCV's DNN API.
|
||||||
|
|
||||||
```
|
## Usage
|
||||||
# git clone ultralytics
|
|
||||||
|
```commandline
|
||||||
|
git clone ultralytics
|
||||||
|
cd ultralytics
|
||||||
pip install .
|
pip install .
|
||||||
cd examples/cpp_
|
cd examples/cpp_
|
||||||
|
|
||||||
Add a **yolov8\_.onnx** and/or **yolov5\_.onnx** model(s) to the ultralytics folder.
|
# Add a **yolov8\_.onnx** and/or **yolov5\_.onnx** model(s) to the ultralytics folder.
|
||||||
Edit the **main.cpp** to change the **projectBasePath** to match your user.
|
# Edit the **main.cpp** to change the **projectBasePath** to match your user.
|
||||||
|
|
||||||
Note that by default the CMake file will try and import the CUDA library to be used with the OpenCVs dnn (cuDNN) GPU Inference.
|
# Note that by default the CMake file will try and import the CUDA library to be used with the OpenCVs dnn (cuDNN) GPU Inference.
|
||||||
If your OpenCV build does not use CUDA/cuDNN you can remove that import call and run the example on CPU.
|
# If your OpenCV build does not use CUDA/cuDNN you can remove that import call and run the example on CPU.
|
||||||
|
|
||||||
mkdir build
|
mkdir build
|
||||||
cd build
|
cd build
|
||||||
@ -20,24 +23,18 @@ make
|
|||||||
./Yolov8CPPInference
|
./Yolov8CPPInference
|
||||||
```
|
```
|
||||||
|
|
||||||
To export yolov8 models:
|
## Exporting YOLOv8 and YOLOv5 Models
|
||||||
|
|
||||||
```
|
To export YOLOv8 models:
|
||||||
yolo export \
|
|
||||||
model=yolov8s.pt \
|
```commandline
|
||||||
imgsz=[480,640] \
|
yolo export model=yolov8s.pt imgsz=480,640 format=onnx opset=12
|
||||||
format=onnx \
|
|
||||||
opset=12
|
|
||||||
```
|
```
|
||||||
|
|
||||||
To export yolov5 models:
|
To export YOLOv5 models:
|
||||||
|
|
||||||
```
|
```commandline
|
||||||
python3 export.py \
|
python3 export.py --weights yolov5s.pt --img 480 640 --include onnx --opset 12
|
||||||
--weights yolov5s.pt \
|
|
||||||
--img 480 640 \
|
|
||||||
--include onnx \
|
|
||||||
--opset 12
|
|
||||||
```
|
```
|
||||||
|
|
||||||
yolov8s.onnx:
|
yolov8s.onnx:
|
||||||
@ -48,10 +45,6 @@ yolov5s.onnx:
|
|||||||
|
|
||||||

|

|
||||||
|
|
||||||
This repository is based on OpenCVs dnn API to run an ONNX exported model of either yolov5/yolov8 (In theory should work
|
This repository utilizes OpenCV's DNN API to run ONNX exported models of YOLOv5 and YOLOv8. In theory, it should work for YOLOv6 and YOLOv7 as well, but they have not been tested. Note that the example networks are exported with rectangular (640x480) resolutions, but any exported resolution will work. You may want to use the letterbox approach for square images, depending on your use case.
|
||||||
for yolov6 and yolov7 but not tested). Note that for this example the networks are exported as rectangular (640x480)
|
|
||||||
resolutions, but it would work for any resolution that you export as although you might want to use the letterBox
|
|
||||||
approach for square images depending on your use-case.
|
|
||||||
|
|
||||||
The **main** branch version is based on using Qt as a GUI wrapper the main interest here is the **Inference** class file
|
The **main** branch version uses Qt as a GUI wrapper. The primary focus here is the **Inference** class file, which demonstrates how to transpose YOLOv8 models to work as YOLOv5 models.
|
||||||
which shows how to transpose yolov8 models to work as yolov5 models.
|
|
||||||
|
@ -83,7 +83,7 @@ std::vector<Detection> Inference::runInference(const cv::Mat &input)
|
|||||||
{
|
{
|
||||||
float confidence = data[4];
|
float confidence = data[4];
|
||||||
|
|
||||||
if (confidence >= modelConfidenseThreshold)
|
if (confidence >= modelConfidenceThreshold)
|
||||||
{
|
{
|
||||||
float *classes_scores = data+5;
|
float *classes_scores = data+5;
|
||||||
|
|
||||||
|
@ -40,7 +40,7 @@ private:
|
|||||||
|
|
||||||
cv::Size2f modelShape{};
|
cv::Size2f modelShape{};
|
||||||
|
|
||||||
float modelConfidenseThreshold {0.25};
|
float modelConfidenceThreshold {0.25};
|
||||||
float modelScoreThreshold {0.45};
|
float modelScoreThreshold {0.45};
|
||||||
float modelNMSThreshold {0.50};
|
float modelNMSThreshold {0.50};
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ def main(onnx_model, input_image):
|
|||||||
image[0:height, 0:width] = original_image
|
image[0:height, 0:width] = original_image
|
||||||
scale = length / 640
|
scale = length / 640
|
||||||
|
|
||||||
blob = cv2.dnn.blobFromImage(image, scalefactor=1 / 255, size=(640, 640))
|
blob = cv2.dnn.blobFromImage(image, scalefactor=1 / 255, size=(640, 640), swapRB=True)
|
||||||
model.setInput(blob)
|
model.setInput(blob)
|
||||||
outputs = model.forward()
|
outputs = model.forward()
|
||||||
|
|
||||||
|
@ -207,10 +207,10 @@ def test_predict_callback_and_setup():
|
|||||||
def test_result():
|
def test_result():
|
||||||
model = YOLO('yolov8n-seg.pt')
|
model = YOLO('yolov8n-seg.pt')
|
||||||
res = model([SOURCE, SOURCE])
|
res = model([SOURCE, SOURCE])
|
||||||
res[0].plot(show_conf=False)
|
res[0].plot(show_conf=False) # raises warning
|
||||||
|
res[0].plot(conf=True, boxes=False, masks=True)
|
||||||
res[0] = res[0].cpu().numpy()
|
res[0] = res[0].cpu().numpy()
|
||||||
print(res[0].path, res[0].masks.masks)
|
print(res[0].path, res[0].masks.masks)
|
||||||
|
|
||||||
model = YOLO('yolov8n.pt')
|
model = YOLO('yolov8n.pt')
|
||||||
res = model(SOURCE)
|
res = model(SOURCE)
|
||||||
res[0].plot()
|
res[0].plot()
|
||||||
@ -218,5 +218,5 @@ def test_result():
|
|||||||
|
|
||||||
model = YOLO('yolov8n-cls.pt')
|
model = YOLO('yolov8n-cls.pt')
|
||||||
res = model(SOURCE)
|
res = model(SOURCE)
|
||||||
res[0].plot()
|
res[0].plot(probs=False)
|
||||||
print(res[0].path)
|
print(res[0].path)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
|
|
||||||
__version__ = '8.0.61'
|
__version__ = '8.0.62'
|
||||||
|
|
||||||
from ultralytics.hub import start
|
from ultralytics.hub import start
|
||||||
from ultralytics.yolo.engine.model import YOLO
|
from ultralytics.yolo.engine.model import YOLO
|
||||||
|
@ -9,8 +9,8 @@ from types import SimpleNamespace
|
|||||||
from typing import Dict, List, Union
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, ROOT, USER_CONFIG_DIR,
|
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, ROOT, USER_CONFIG_DIR,
|
||||||
IterableSimpleNamespace, __version__, checks, colorstr, get_settings, yaml_load,
|
IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn,
|
||||||
yaml_print)
|
get_settings, yaml_load, yaml_print)
|
||||||
|
|
||||||
# Define valid tasks and modes
|
# Define valid tasks and modes
|
||||||
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
|
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
|
||||||
@ -71,7 +71,7 @@ CFG_INT_KEYS = ('epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic'
|
|||||||
'line_thickness', 'workspace', 'nbs', 'save_period')
|
'line_thickness', 'workspace', 'nbs', 'save_period')
|
||||||
CFG_BOOL_KEYS = ('save', 'exist_ok', 'verbose', 'deterministic', 'single_cls', 'image_weights', 'rect', 'cos_lr',
|
CFG_BOOL_KEYS = ('save', 'exist_ok', 'verbose', 'deterministic', 'single_cls', 'image_weights', 'rect', 'cos_lr',
|
||||||
'overlap_mask', 'val', 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt',
|
'overlap_mask', 'val', 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt',
|
||||||
'save_conf', 'save_crop', 'hide_labels', 'hide_conf', 'visualize', 'augment', 'agnostic_nms',
|
'save_conf', 'save_crop', 'show_labels', 'show_conf', 'visualize', 'augment', 'agnostic_nms',
|
||||||
'retina_masks', 'boxes', 'keras', 'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'v5loader')
|
'retina_masks', 'boxes', 'keras', 'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'v5loader')
|
||||||
|
|
||||||
|
|
||||||
@ -140,6 +140,22 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
|
|||||||
return IterableSimpleNamespace(**cfg)
|
return IterableSimpleNamespace(**cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_deprecation(custom):
|
||||||
|
"""
|
||||||
|
Hardcoded function to handle deprecated config keys
|
||||||
|
"""
|
||||||
|
|
||||||
|
for key in custom.copy().keys():
|
||||||
|
if key == 'hide_labels':
|
||||||
|
deprecation_warn(key, 'show_labels')
|
||||||
|
custom['show_labels'] = custom.pop('hide_labels') == 'False'
|
||||||
|
if key == 'hide_conf':
|
||||||
|
deprecation_warn(key, 'show_conf')
|
||||||
|
custom['show_conf'] = custom.pop('hide_conf') == 'False'
|
||||||
|
|
||||||
|
return custom
|
||||||
|
|
||||||
|
|
||||||
def check_cfg_mismatch(base: Dict, custom: Dict, e=None):
|
def check_cfg_mismatch(base: Dict, custom: Dict, e=None):
|
||||||
"""
|
"""
|
||||||
This function checks for any mismatched keys between a custom configuration list and a base configuration list.
|
This function checks for any mismatched keys between a custom configuration list and a base configuration list.
|
||||||
@ -149,6 +165,7 @@ def check_cfg_mismatch(base: Dict, custom: Dict, e=None):
|
|||||||
- custom (Dict): a dictionary of custom configuration options
|
- custom (Dict): a dictionary of custom configuration options
|
||||||
- base (Dict): a dictionary of base configuration options
|
- base (Dict): a dictionary of base configuration options
|
||||||
"""
|
"""
|
||||||
|
custom = _handle_deprecation(custom)
|
||||||
base, custom = (set(x.keys()) for x in (base, custom))
|
base, custom = (set(x.keys()) for x in (base, custom))
|
||||||
mismatched = [x for x in custom if x not in base]
|
mismatched = [x for x in custom if x not in base]
|
||||||
if mismatched:
|
if mismatched:
|
||||||
|
@ -55,8 +55,8 @@ show: False # show results if possible
|
|||||||
save_txt: False # save results as .txt file
|
save_txt: False # save results as .txt file
|
||||||
save_conf: False # save results with confidence scores
|
save_conf: False # save results with confidence scores
|
||||||
save_crop: False # save cropped images with results
|
save_crop: False # save cropped images with results
|
||||||
hide_labels: False # hide labels
|
show_labels: True # show object labels in plots
|
||||||
hide_conf: False # hide confidence scores
|
show_conf: True # show object confidence scores in plots
|
||||||
vid_stride: 1 # video frame-rate stride
|
vid_stride: 1 # video frame-rate stride
|
||||||
line_thickness: 3 # bounding box thickness (pixels)
|
line_thickness: 3 # bounding box thickness (pixels)
|
||||||
visualize: False # visualize model features
|
visualize: False # visualize model features
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@ -77,7 +78,7 @@ class YOLO:
|
|||||||
task (Any, optional): Task type for the YOLO model. Defaults to None.
|
task (Any, optional): Task type for the YOLO model. Defaults to None.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self._reset_callbacks()
|
self.callbacks = deepcopy(callbacks.default_callbacks)
|
||||||
self.predictor = None # reuse predictor
|
self.predictor = None # reuse predictor
|
||||||
self.model = None # model object
|
self.model = None # model object
|
||||||
self.trainer = None # trainer object
|
self.trainer = None # trainer object
|
||||||
@ -91,7 +92,7 @@ class YOLO:
|
|||||||
model = str(model).strip() # strip spaces
|
model = str(model).strip() # strip spaces
|
||||||
|
|
||||||
# Check if Ultralytics HUB model from https://hub.ultralytics.com
|
# Check if Ultralytics HUB model from https://hub.ultralytics.com
|
||||||
if model.startswith('https://hub.ultralytics.com/models/'):
|
if self.is_hub_model(model):
|
||||||
from ultralytics.hub.session import HUBTrainingSession
|
from ultralytics.hub.session import HUBTrainingSession
|
||||||
self.session = HUBTrainingSession(model)
|
self.session = HUBTrainingSession(model)
|
||||||
model = self.session.model_file
|
model = self.session.model_file
|
||||||
@ -112,6 +113,13 @@ class YOLO:
|
|||||||
name = self.__class__.__name__
|
name = self.__class__.__name__
|
||||||
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_hub_model(model):
|
||||||
|
return any((
|
||||||
|
model.startswith('https://hub.ultralytics.com/models/'),
|
||||||
|
[len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID
|
||||||
|
(len(model) == 20 and not Path(model).exists() and not any(x in model for x in './\\')))) # MODELID
|
||||||
|
|
||||||
def _new(self, cfg: str, task=None, verbose=True):
|
def _new(self, cfg: str, task=None, verbose=True):
|
||||||
"""
|
"""
|
||||||
Initializes a new model and infers the task type from the model definitions.
|
Initializes a new model and infers the task type from the model definitions.
|
||||||
@ -220,8 +228,7 @@ class YOLO:
|
|||||||
if source is None:
|
if source is None:
|
||||||
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
|
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
|
||||||
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
|
||||||
is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and \
|
is_cli = sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')
|
||||||
('predict' in sys.argv or 'mode=predict' in sys.argv)
|
|
||||||
|
|
||||||
overrides = self.overrides.copy()
|
overrides = self.overrides.copy()
|
||||||
overrides['conf'] = 0.25
|
overrides['conf'] = 0.25
|
||||||
@ -231,7 +238,7 @@ class YOLO:
|
|||||||
overrides['save'] = kwargs.get('save', False) # not save files by default
|
overrides['save'] = kwargs.get('save', False) # not save files by default
|
||||||
if not self.predictor:
|
if not self.predictor:
|
||||||
self.task = overrides.get('task') or self.task
|
self.task = overrides.get('task') or self.task
|
||||||
self.predictor = TASK_MAP[self.task][3](overrides=overrides)
|
self.predictor = TASK_MAP[self.task][3](overrides=overrides, _callbacks=self.callbacks)
|
||||||
self.predictor.setup_model(model=self.model, verbose=is_cli)
|
self.predictor.setup_model(model=self.model, verbose=is_cli)
|
||||||
else: # only update args if predictor is already setup
|
else: # only update args if predictor is already setup
|
||||||
self.predictor.args = get_cfg(self.predictor.args, overrides)
|
self.predictor.args = get_cfg(self.predictor.args, overrides)
|
||||||
@ -380,19 +387,17 @@ class YOLO:
|
|||||||
"""
|
"""
|
||||||
return self.model.transforms if hasattr(self.model, 'transforms') else None
|
return self.model.transforms if hasattr(self.model, 'transforms') else None
|
||||||
|
|
||||||
@staticmethod
|
def add_callback(self, event: str, func):
|
||||||
def add_callback(event: str, func):
|
|
||||||
"""
|
"""
|
||||||
Add callback
|
Add callback
|
||||||
"""
|
"""
|
||||||
callbacks.default_callbacks[event].append(func)
|
self.callbacks[event].append(func)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reset_ckpt_args(args):
|
def _reset_ckpt_args(args):
|
||||||
include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
|
include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
|
||||||
return {k: v for k, v in args.items() if k in include}
|
return {k: v for k, v in args.items() if k in include}
|
||||||
|
|
||||||
@staticmethod
|
def _reset_callbacks(self):
|
||||||
def _reset_callbacks():
|
|
||||||
for event in callbacks.default_callbacks.keys():
|
for event in callbacks.default_callbacks.keys():
|
||||||
callbacks.default_callbacks[event] = [callbacks.default_callbacks[event][0]]
|
self.callbacks[event] = [callbacks.default_callbacks[event][0]]
|
||||||
|
@ -75,7 +75,7 @@ class BasePredictor:
|
|||||||
data_path (str): Path to data.
|
data_path (str): Path to data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||||
"""
|
"""
|
||||||
Initializes the BasePredictor class.
|
Initializes the BasePredictor class.
|
||||||
|
|
||||||
@ -104,7 +104,7 @@ class BasePredictor:
|
|||||||
self.data_path = None
|
self.data_path = None
|
||||||
self.source_type = None
|
self.source_type = None
|
||||||
self.batch = None
|
self.batch = None
|
||||||
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
|
self.callbacks = defaultdict(list, _callbacks) if _callbacks else defaultdict(list, callbacks.default_callbacks)
|
||||||
callbacks.add_integration_callbacks(self)
|
callbacks.add_integration_callbacks(self)
|
||||||
|
|
||||||
def preprocess(self, img):
|
def preprocess(self, img):
|
||||||
|
@ -12,7 +12,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torchvision.transforms.functional as F
|
import torchvision.transforms.functional as F
|
||||||
|
|
||||||
from ultralytics.yolo.utils import LOGGER, SimpleClass, ops
|
from ultralytics.yolo.utils import LOGGER, SimpleClass, deprecation_warn, ops
|
||||||
from ultralytics.yolo.utils.plotting import Annotator, colors
|
from ultralytics.yolo.utils.plotting import Annotator, colors
|
||||||
from ultralytics.yolo.utils.torch_utils import TORCHVISION_0_10
|
from ultralytics.yolo.utils.torch_utils import TORCHVISION_0_10
|
||||||
|
|
||||||
@ -65,7 +65,7 @@ class Results(SimpleClass):
|
|||||||
self.boxes = Boxes(boxes, self.orig_shape)
|
self.boxes = Boxes(boxes, self.orig_shape)
|
||||||
if masks is not None:
|
if masks is not None:
|
||||||
self.masks = Masks(masks, self.orig_shape)
|
self.masks = Masks(masks, self.orig_shape)
|
||||||
if boxes is not None:
|
if probs is not None:
|
||||||
self.probs = probs
|
self.probs = probs
|
||||||
|
|
||||||
def cpu(self):
|
def cpu(self):
|
||||||
@ -100,46 +100,72 @@ class Results(SimpleClass):
|
|||||||
def keys(self):
|
def keys(self):
|
||||||
return [k for k in self._keys if getattr(self, k) is not None]
|
return [k for k in self._keys if getattr(self, k) is not None]
|
||||||
|
|
||||||
def plot(self, show_conf=True, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
|
def plot(
|
||||||
|
self,
|
||||||
|
conf=True,
|
||||||
|
line_width=None,
|
||||||
|
font_size=None,
|
||||||
|
font='Arial.ttf',
|
||||||
|
pil=False,
|
||||||
|
example='abc',
|
||||||
|
img=None,
|
||||||
|
labels=True,
|
||||||
|
boxes=True,
|
||||||
|
masks=True,
|
||||||
|
probs=True,
|
||||||
|
**kwargs # deprecated args TODO: remove support in 8.2
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image.
|
Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
show_conf (bool): Whether to show the detection confidence score.
|
conf (bool): Whether to plot the detection confidence score.
|
||||||
line_width (float, optional): The line width of the bounding boxes. If None, it is scaled to the image size.
|
line_width (float, optional): The line width of the bounding boxes. If None, it is scaled to the image size.
|
||||||
font_size (float, optional): The font size of the text. If None, it is scaled to the image size.
|
font_size (float, optional): The font size of the text. If None, it is scaled to the image size.
|
||||||
font (str): The font to use for the text.
|
font (str): The font to use for the text.
|
||||||
pil (bool): Whether to return the image as a PIL Image.
|
pil (bool): Whether to return the image as a PIL Image.
|
||||||
example (str): An example string to display. Useful for indicating the expected format of the output.
|
example (str): An example string to display. Useful for indicating the expected format of the output.
|
||||||
|
img (numpy.ndarray): Plot to another image. if not, plot to original image.
|
||||||
|
labels (bool): Whether to plot the label of bounding boxes.
|
||||||
|
boxes (bool): Whether to plot the bounding boxes.
|
||||||
|
masks (bool): Whether to plot the masks.
|
||||||
|
probs (bool): Whether to plot classification probability
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(None) or (PIL.Image): If `pil` is True, a PIL Image is returned. Otherwise, nothing is returned.
|
(None) or (PIL.Image): If `pil` is True, a PIL Image is returned. Otherwise, nothing is returned.
|
||||||
"""
|
"""
|
||||||
annotator = Annotator(deepcopy(self.orig_img), line_width, font_size, font, pil, example)
|
# Deprecation warn TODO: remove in 8.2
|
||||||
boxes = self.boxes
|
if 'show_conf' in kwargs:
|
||||||
masks = self.masks
|
deprecation_warn('show_conf', 'conf')
|
||||||
probs = self.probs
|
conf = kwargs['show_conf']
|
||||||
|
assert type(conf) == bool, '`show_conf` should be of boolean type, i.e, show_conf=True/False'
|
||||||
|
|
||||||
|
annotator = Annotator(deepcopy(self.orig_img if img is None else img), line_width, font_size, font, pil,
|
||||||
|
example)
|
||||||
|
pred_boxes, show_boxes = self.boxes, boxes
|
||||||
|
pred_masks, show_masks = self.masks, masks
|
||||||
|
pred_probs, show_probs = self.probs, probs
|
||||||
names = self.names
|
names = self.names
|
||||||
hide_labels, hide_conf = False, not show_conf
|
if pred_boxes and show_boxes:
|
||||||
if boxes is not None:
|
for d in reversed(pred_boxes):
|
||||||
for d in reversed(boxes):
|
c, conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item())
|
||||||
c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item())
|
|
||||||
name = ('' if id is None else f'id:{id} ') + names[c]
|
name = ('' if id is None else f'id:{id} ') + names[c]
|
||||||
label = None if hide_labels else (name if hide_conf else f'{name} {conf:.2f}')
|
label = (name if not conf else f'{name} {conf:.2f}') if labels else None
|
||||||
annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
|
annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
|
||||||
|
|
||||||
if masks is not None:
|
if pred_masks and show_masks:
|
||||||
im = torch.as_tensor(annotator.im, dtype=torch.float16, device=masks.data.device).permute(2, 0, 1).flip(0)
|
im = torch.as_tensor(annotator.im, dtype=torch.float16, device=pred_masks.data.device).permute(2, 0,
|
||||||
|
1).flip(0)
|
||||||
if TORCHVISION_0_10:
|
if TORCHVISION_0_10:
|
||||||
im = F.resize(im.contiguous(), masks.data.shape[1:], antialias=True) / 255
|
im = F.resize(im.contiguous(), pred_masks.data.shape[1:], antialias=True) / 255
|
||||||
else:
|
else:
|
||||||
im = F.resize(im.contiguous(), masks.data.shape[1:]) / 255
|
im = F.resize(im.contiguous(), pred_masks.data.shape[1:]) / 255
|
||||||
annotator.masks(masks.data, colors=[colors(x, True) for x in boxes.cls], im_gpu=im)
|
annotator.masks(pred_masks.data, colors=[colors(x, True) for x in pred_boxes.cls], im_gpu=im)
|
||||||
|
|
||||||
if probs is not None:
|
if pred_probs is not None and show_probs:
|
||||||
n5 = min(len(names), 5)
|
n5 = min(len(names), 5)
|
||||||
top5i = probs.argsort(0, descending=True)[:n5].tolist() # top 5 indices
|
top5i = pred_probs.argsort(0, descending=True)[:n5].tolist() # top 5 indices
|
||||||
text = f"{', '.join(f'{names[j] if names else j} {probs[j]:.2f}' for j in top5i)}, "
|
text = f"{', '.join(f'{names[j] if names else j} {pred_probs[j]:.2f}' for j in top5i)}, "
|
||||||
annotator.text((32, 32), text, txt_color=(255, 255, 255)) # TODO: allow setting colors
|
annotator.text((32, 32), text, txt_color=(255, 255, 255)) # TODO: allow setting colors
|
||||||
|
|
||||||
return np.asarray(annotator.im) if annotator.pil else annotator.im
|
return np.asarray(annotator.im) if annotator.pil else annotator.im
|
||||||
|
@ -624,7 +624,8 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.3'):
|
|||||||
|
|
||||||
# Check that settings keys and types match defaults
|
# Check that settings keys and types match defaults
|
||||||
correct = \
|
correct = \
|
||||||
settings.keys() == defaults.keys() \
|
settings \
|
||||||
|
and settings.keys() == defaults.keys() \
|
||||||
and all(type(a) == type(b) for a, b in zip(settings.values(), defaults.values())) \
|
and all(type(a) == type(b) for a, b in zip(settings.values(), defaults.values())) \
|
||||||
and check_version(settings['settings_version'], version)
|
and check_version(settings['settings_version'], version)
|
||||||
if not correct:
|
if not correct:
|
||||||
@ -646,6 +647,14 @@ def set_settings(kwargs, file=USER_CONFIG_DIR / 'settings.yaml'):
|
|||||||
yaml_save(file, SETTINGS)
|
yaml_save(file, SETTINGS)
|
||||||
|
|
||||||
|
|
||||||
|
def deprecation_warn(arg, new_arg, version=None):
|
||||||
|
if not version:
|
||||||
|
version = float(__version__[0:3]) + 0.2 # deprecate after 2nd major release
|
||||||
|
LOGGER.warning(
|
||||||
|
f'WARNING: `{arg}` is deprecated and will be removed in upcoming major release {version}. Use `{new_arg}` instead'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Run below code on yolo/utils init ------------------------------------------------------------------------------------
|
# Run below code on yolo/utils init ------------------------------------------------------------------------------------
|
||||||
|
|
||||||
# Check first-install steps
|
# Check first-install steps
|
||||||
|
@ -70,7 +70,7 @@ class DetectionPredictor(BasePredictor):
|
|||||||
f.write(('%g ' * len(line)).rstrip() % line + '\n')
|
f.write(('%g ' * len(line)).rstrip() % line + '\n')
|
||||||
if self.args.save or self.args.show: # Add bbox to image
|
if self.args.save or self.args.show: # Add bbox to image
|
||||||
name = ('' if id is None else f'id:{id} ') + self.model.names[c]
|
name = ('' if id is None else f'id:{id} ') + self.model.names[c]
|
||||||
label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}')
|
label = (f'{name} {conf:.2f}' if self.args.show_conf else name) if self.args.show_labels else None
|
||||||
self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
|
self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
|
||||||
if self.args.save_crop:
|
if self.args.save_crop:
|
||||||
save_one_box(d.xyxy,
|
save_one_box(d.xyxy,
|
||||||
|
@ -84,7 +84,7 @@ class SegmentationPredictor(DetectionPredictor):
|
|||||||
f.write(('%g ' * len(line)).rstrip() % line + '\n')
|
f.write(('%g ' * len(line)).rstrip() % line + '\n')
|
||||||
if self.args.save or self.args.show: # Add bbox to image
|
if self.args.save or self.args.show: # Add bbox to image
|
||||||
name = ('' if id is None else f'id:{id} ') + self.model.names[c]
|
name = ('' if id is None else f'id:{id} ') + self.model.names[c]
|
||||||
label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}')
|
label = (f'{name} {conf:.2f}' if self.args.show_conf else name) if self.args.show_labels else None
|
||||||
if self.args.boxes:
|
if self.args.boxes:
|
||||||
self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
|
self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
|
||||||
if self.args.save_crop:
|
if self.args.save_crop:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user