mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-10-26 03:05:39 +08:00 
			
		
		
		
	ultralytics 8.0.42 DDP fix and Docs updates (#1065)
				
					
				
			Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: Noobtoss <96134731+Noobtoss@users.noreply.github.com> Co-authored-by: Laughing-q <1185102784@qq.com>
This commit is contained in:
		
							parent
							
								
									f6e393c1d2
								
							
						
					
					
						commit
						f2a7a29e53
					
				
							
								
								
									
										2
									
								
								.github/workflows/ci.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/ci.yaml
									
									
									
									
										vendored
									
									
								
							| @ -17,7 +17,7 @@ jobs: | ||||
|     strategy: | ||||
|       fail-fast: false | ||||
|       matrix: | ||||
|         os: [ubuntu-latest, windows-latest, macos-latest] | ||||
|         os: [ubuntu-latest] | ||||
|         python-version: ['3.10']  # requires python<=3.9 | ||||
|         model: [yolov8n] | ||||
|     steps: | ||||
|  | ||||
| @ -32,10 +32,11 @@ predictor's call method. | ||||
| 
 | ||||
| Results object consists of these component objects: | ||||
| 
 | ||||
| - `Results.boxes` : `Boxes` object with properties and methods for manipulating bboxes | ||||
| - `Results.masks` : `Masks` object used to index masks or to get segment coordinates. | ||||
| - `Results.probs` : `torch.Tensor` containing the class probabilities/logits. | ||||
| - `Results.orig_shape` : `tuple` containing the original image size as (height, width). | ||||
| - `Results.boxes`: `Boxes` object with properties and methods for manipulating bboxes | ||||
| - `Results.masks`: `Masks` object used to index masks or to get segment coordinates. | ||||
| - `Results.probs`: `torch.Tensor` containing the class probabilities/logits. | ||||
| - `Results.orig_img`: Original image loaded in memory. | ||||
| - `Results.path`: `Path` containing the path to input image | ||||
| 
 | ||||
| Each result is composed of torch.Tensor by default, in which you can easily use following functionality: | ||||
| 
 | ||||
| @ -94,18 +95,18 @@ results[0].probs  # cls prob, (num_class, ) | ||||
| 
 | ||||
| Class reference documentation for `Results` module and its components can be found [here](reference/results.md) | ||||
| 
 | ||||
| ## Visualizing results | ||||
| ## Plotting results | ||||
| 
 | ||||
| You can use `visualize()` function of `Result` object to get a visualization. It plots all components(boxes, masks, | ||||
| You can use `plot()` function of `Result` object to plot results on in image object. It plots all components(boxes, masks, | ||||
| classification logits, etc) found in the results object | ||||
| 
 | ||||
| ```python | ||||
|     res = model(img) | ||||
|     res_plotted = res[0].visualize() | ||||
|     cv2.imshow("result", res_plotted) | ||||
| res = model(img) | ||||
| res_plotted = res[0].plot() | ||||
| cv2.imshow("result", res_plotted) | ||||
| ``` | ||||
| 
 | ||||
| !!! example "`visualize()` arguments" | ||||
| !!! example "`plot()` arguments" | ||||
| 
 | ||||
|     `show_conf (bool)`: Show confidence | ||||
| 
 | ||||
|  | ||||
							
								
								
									
										86
									
								
								docs/tasks/tracking.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								docs/tasks/tracking.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,86 @@ | ||||
| Object tracking is a task that involves identifying the location and class of objects, then assigning a unique ID to | ||||
| that detection in video streams. | ||||
| 
 | ||||
| The output of tracker is the same as detection with an added object ID. | ||||
| 
 | ||||
| ## Available Trackers | ||||
| 
 | ||||
| The following tracking algorithms have been implemented and can be enabled by passing `tracker=tracker_type.yaml` | ||||
| 
 | ||||
| * [BoT-SORT](https://github.com/NirAharon/BoT-SORT) - `botsort.yaml` | ||||
| * [ByteTrack](https://github.com/ifzhang/ByteTrack) - `bytetrack.yaml` | ||||
| 
 | ||||
| The default tracker is BoT-SORT. | ||||
| 
 | ||||
| ## Tracking | ||||
| 
 | ||||
| Use a trained YOLOv8n/YOLOv8n-seg model to run tracker on video streams. | ||||
| 
 | ||||
| !!! example "" | ||||
| 
 | ||||
|     === "Python" | ||||
|      | ||||
|         ```python | ||||
|         from ultralytics import YOLO | ||||
|          | ||||
|         # Load a model | ||||
|         model = YOLO("yolov8n.pt")  # load an official detection model | ||||
|         model = YOLO("yolov8n-seg.pt")  # load an official segmentation model | ||||
|         model = YOLO("path/to/best.pt")  # load a custom model | ||||
|          | ||||
|         # Track with the model | ||||
|         results = model.track(source="https://youtu.be/Zgi9g1ksQHc", show=True)  | ||||
|         results = model.track(source="https://youtu.be/Zgi9g1ksQHc", show=True, tracker="bytetrack.yaml")  | ||||
|         ``` | ||||
|     === "CLI" | ||||
|      | ||||
|         ```bash | ||||
|         yolo track model=yolov8n.pt source="https://youtu.be/Zgi9g1ksQHc"  # official detection model | ||||
|         yolo track model=yolov8n-seg.pt source=...   # official segmentation model | ||||
|         yolo track model=path/to/best.pt source=...  # custom model | ||||
|         yolo track model=path/to/best.pt  tracker="bytetrack.yaml" # bytetrack tracker | ||||
| 
 | ||||
|         ``` | ||||
| 
 | ||||
| As in the above usage, we support both the detection and segmentation models for tracking and the only thing you need to do is loading the corresponding(detection or segmentation) model. | ||||
| 
 | ||||
| ## Configuration | ||||
| ### Tracking | ||||
| Tracking shares the configuration with predict, i.e `conf`, `iou`, `show`. More configurations please refer to [predict page](https://docs.ultralytics.com/cfg/#prediction). | ||||
| !!! example "" | ||||
| 
 | ||||
|     === "Python" | ||||
|      | ||||
|         ```python | ||||
|         from ultralytics import YOLO | ||||
|          | ||||
|         model = YOLO("yolov8n.pt") | ||||
|         results = model.track(source="https://youtu.be/Zgi9g1ksQHc", conf=0.3, iou=0.5, show=True)  | ||||
|         ``` | ||||
|     === "CLI" | ||||
|      | ||||
|         ```bash | ||||
|         yolo track model=yolov8n.pt source="https://youtu.be/Zgi9g1ksQHc" conf=0.3, iou=0.5 show | ||||
| 
 | ||||
|         ``` | ||||
| 
 | ||||
| ### Tracker | ||||
| We also support using a modified tracker config file, just copy a config file i.e `custom_tracker.yaml` from [ultralytics/tracker/cfg](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/tracker/cfg) and modify any configurations(expect the `tracker_type`) you need to. | ||||
| !!! example "" | ||||
| 
 | ||||
|     === "Python" | ||||
|      | ||||
|         ```python | ||||
|         from ultralytics import YOLO | ||||
|          | ||||
|         model = YOLO("yolov8n.pt") | ||||
|         results = model.track(source="https://youtu.be/Zgi9g1ksQHc", tracker='custom_tracker.yaml')  | ||||
|         ``` | ||||
|     === "CLI" | ||||
|      | ||||
|         ```bash | ||||
|         yolo track model=yolov8n.pt source="https://youtu.be/Zgi9g1ksQHc" tracker='custom_tracker.yaml' | ||||
| 
 | ||||
|         ``` | ||||
| Please refer to [ultralytics/tracker/cfg](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/tracker/cfg) page.  | ||||
| 
 | ||||
| @ -2,10 +2,11 @@ This is a list of real-world applications and walkthroughs. These can be folders | ||||
| 
 | ||||
| ## Ultralytics YOLO example applications | ||||
| 
 | ||||
| | Title                                                           | Format             | Contributor                                         | | ||||
| | --------------------------------------------------------------- | ------------------ | --------------------------------------------------- | | ||||
| | [Yolov8/yolov5 ONNX Inference with C++](./YOLOv8-CPP-Inference) | C++/ONNX           | [Justas Bartnykas](https://github.com/JustasBart)   | | ||||
| | [YOLOv8-OpenCV-ONNX-Python](./YOLOv8-OpenCV-ONNX-Python)        | OpenCV/Python/ONNX | [Farid Inawan](https://github.com/frdteknikelektro) | | ||||
| | Title                                                                    | Format             | Contributor                                         | | ||||
| | ------------------------------------------------------------------------ | ------------------ | --------------------------------------------------- | | ||||
| | [YOLO ONNX detection Inference with C++](./YOLOv8_CPP_Inference)         | C++/ONNX           | [Justas Bartnykas](https://github.com/JustasBart)   | | ||||
| | [YOLO OpenCV ONNX detection Python](./YOLOv8-OpenCV-ONNX-Python)         | OpenCV/Python/ONNX | [Farid Inawan](https://github.com/frdteknikelektro) | | ||||
| | [YOLO .Net ONNX detection C#](https://www.nuget.org/packages/Yolov8.Net) | C# .Net            | [Samuel Stainback](https://github.com/sstainba)     | | ||||
| 
 | ||||
| ## How can you contribute ? | ||||
| 
 | ||||
|  | ||||
| @ -123,7 +123,7 @@ | ||||
|             "Downloading https://ultralytics.com/images/zidane.jpg to zidane.jpg...\n", | ||||
|             "100% 165k/165k [00:00<00:00, 87.4MB/s]\n", | ||||
|             "image 1/1 /content/zidane.jpg: 384x640 2 persons, 1 tie, 13.3ms\n", | ||||
|             "Speed: 0.5ms pre-process, 13.3ms inference, 43.5ms postprocess per image at shape (1, 3, 640, 640)\n", | ||||
|             "Speed: 0.5ms preprocess, 13.3ms inference, 43.5ms postprocess per image at shape (1, 3, 640, 640)\n", | ||||
|             "Results saved to \u001b[1mruns/detect/predict\u001b[0m\n" | ||||
|           ] | ||||
|         } | ||||
| @ -268,7 +268,7 @@ | ||||
|             "              scissors        128          1          1          0      0.249     0.0746\n", | ||||
|             "            teddy bear        128         21      0.877      0.333      0.591      0.394\n", | ||||
|             "            toothbrush        128          5      0.743        0.6      0.638      0.374\n", | ||||
|             "Speed: 2.4ms pre-process, 7.8ms inference, 0.0ms loss, 3.3ms post-process per image\n" | ||||
|             "Speed: 2.4ms preprocess, 7.8ms inference, 0.0ms loss, 3.3ms postprocess per image\n" | ||||
|           ] | ||||
|         } | ||||
|       ] | ||||
| @ -439,7 +439,7 @@ | ||||
|             "              scissors        128          1          1          0      0.142     0.0426\n", | ||||
|             "            teddy bear        128         21      0.587      0.476       0.63      0.458\n", | ||||
|             "            toothbrush        128          5      0.784      0.736      0.898      0.544\n", | ||||
|             "Speed: 2.0ms pre-process, 4.0ms inference, 0.0ms loss, 2.5ms post-process per image\n", | ||||
|             "Speed: 2.0ms preprocess, 4.0ms inference, 0.0ms loss, 2.5ms postprocess per image\n", | ||||
|             "Results saved to \u001b[1mruns/detect/train\u001b[0m\n" | ||||
|           ] | ||||
|         } | ||||
|  | ||||
| @ -105,6 +105,7 @@ nav: | ||||
|   - Tasks: | ||||
|       - Detection: tasks/detection.md | ||||
|       - Segmentation: tasks/segmentation.md | ||||
|       - Multi-Object Tracking: tasks/tracking.md | ||||
|       - Classification: tasks/classification.md | ||||
|   - Usage: | ||||
|       - CLI: cli.md | ||||
|  | ||||
| @ -170,15 +170,16 @@ def test_predict_callback_and_setup(): | ||||
| def test_result(): | ||||
|     model = YOLO('yolov8n-seg.pt') | ||||
|     res = model([SOURCE, SOURCE]) | ||||
|     res[0].numpy() | ||||
|     res[0].cpu().numpy() | ||||
|     resimg = res[0].visualize(show_conf=False) | ||||
|     print(resimg) | ||||
|     res[0].plot(show_conf=False) | ||||
|     print(res[0].path) | ||||
| 
 | ||||
|     model = YOLO('yolov8n.pt') | ||||
|     res = model(SOURCE) | ||||
|     res[0].visualize() | ||||
|     res[0].plot() | ||||
|     print(res[0].path) | ||||
| 
 | ||||
|     model = YOLO('yolov8n-cls.pt') | ||||
|     res = model(SOURCE) | ||||
|     res[0].visualize() | ||||
|     res[0].plot() | ||||
|     print(res[0].path) | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| # Ultralytics YOLO 🚀, GPL-3.0 license | ||||
| 
 | ||||
| __version__ = '8.0.41' | ||||
| __version__ = '8.0.42' | ||||
| 
 | ||||
| from ultralytics.yolo.engine.model import YOLO | ||||
| from ultralytics.yolo.utils.checks import check_yolo as checks | ||||
|  | ||||
| @ -232,6 +232,3 @@ class Detections: | ||||
| 
 | ||||
|     def __repr__(self): | ||||
|         return f'YOLOv8 {self.__class__} instance\n' + self.__str__() | ||||
| 
 | ||||
| 
 | ||||
| print('works') | ||||
|  | ||||
| @ -381,7 +381,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False): | ||||
|         return ensemble[-1] | ||||
| 
 | ||||
|     # Return ensemble | ||||
|     print(f'Ensemble created with {weights}\n') | ||||
|     LOGGER.info(f'Ensemble created with {weights}\n') | ||||
|     for k in 'names', 'nc', 'yaml': | ||||
|         setattr(ensemble, k, getattr(ensemble[0], k)) | ||||
|     ensemble.stride = ensemble[torch.argmax(torch.tensor([m.stride.max() for m in ensemble])).int()].stride | ||||
|  | ||||
| @ -16,7 +16,7 @@ model = YOLO("yolov8n.pt")  # or a segmentation model .i.e yolov8n-seg.pt | ||||
| model.track( | ||||
|     source="video/streams", | ||||
|     stream=True, | ||||
|     tracker="botsort.yaml/bytetrack.yaml", | ||||
|     tracker="botsort.yaml",  # or 'bytetrack.yaml' | ||||
|     ..., | ||||
| ) | ||||
| ``` | ||||
|  | ||||
| @ -1 +1,3 @@ | ||||
| # Ultralytics YOLO 🚀, GPL-3.0 license | ||||
| 
 | ||||
| from .trackers import BOTSORT, BYTETracker | ||||
|  | ||||
| @ -1,3 +1,6 @@ | ||||
| # Ultralytics YOLO 🚀, GPL-3.0 license | ||||
| # Default YOLO tracker settings for BoT-SORT tracker https://github.com/NirAharon/BoT-SORT | ||||
| 
 | ||||
| tracker_type: botsort  # tracker type, ['botsort', 'bytetrack'] | ||||
| track_high_thresh: 0.5  # threshold for the first association | ||||
| track_low_thresh: 0.1  # threshold for the second association | ||||
| @ -7,7 +10,7 @@ match_thresh: 0.8  # threshold for matching tracks | ||||
| # min_box_area: 10  # threshold for min box areas(for tracker evaluation, not used for now) | ||||
| # mot20: False  # for tracker evaluation(not used for now) | ||||
| 
 | ||||
| # Botsort settings | ||||
| # BoT-SORT settings | ||||
| cmc_method: sparseOptFlow  # method of global motion compensation | ||||
| # ReID model related thresh (not supported yet) | ||||
| proximity_thresh: 0.5 | ||||
|  | ||||
| @ -1,3 +1,6 @@ | ||||
| # Ultralytics YOLO 🚀, GPL-3.0 license | ||||
| # Default YOLO tracker settings for ByteTrack tracker https://github.com/ifzhang/ByteTrack | ||||
| 
 | ||||
| tracker_type: bytetrack  # tracker type, ['botsort', 'bytetrack'] | ||||
| track_high_thresh: 0.5  # threshold for the first association | ||||
| track_low_thresh: 0.1  # threshold for the second association | ||||
|  | ||||
| @ -1,3 +1,5 @@ | ||||
| # Ultralytics YOLO 🚀, GPL-3.0 license | ||||
| 
 | ||||
| import torch | ||||
| 
 | ||||
| from ultralytics.tracker import BOTSORT, BYTETracker | ||||
|  | ||||
| @ -1,2 +1,4 @@ | ||||
| # Ultralytics YOLO 🚀, GPL-3.0 license | ||||
| 
 | ||||
| from .bot_sort import BOTSORT | ||||
| from .byte_tracker import BYTETracker | ||||
|  | ||||
| @ -1,3 +1,5 @@ | ||||
| # Ultralytics YOLO 🚀, GPL-3.0 license | ||||
| 
 | ||||
| from collections import OrderedDict | ||||
| 
 | ||||
| import numpy as np | ||||
|  | ||||
| @ -1,3 +1,5 @@ | ||||
| # Ultralytics YOLO 🚀, GPL-3.0 license | ||||
| 
 | ||||
| from collections import deque | ||||
| 
 | ||||
| import numpy as np | ||||
| @ -97,7 +99,7 @@ class BOTSORT(BYTETracker): | ||||
|         self.appearance_thresh = args.appearance_thresh | ||||
| 
 | ||||
|         if args.with_reid: | ||||
|             # haven't supported bot-sort(reid) yet | ||||
|             # haven't supported BoT-SORT(reid) yet | ||||
|             self.encoder = None | ||||
|         # self.gmc = GMC(method=args.cmc_method, verbose=[args.name, args.ablation]) | ||||
|         self.gmc = GMC(method=args.cmc_method) | ||||
|  | ||||
| @ -1,3 +1,5 @@ | ||||
| # Ultralytics YOLO 🚀, GPL-3.0 license | ||||
| 
 | ||||
| import numpy as np | ||||
| 
 | ||||
| from ..utils import matching | ||||
|  | ||||
| @ -1,9 +1,13 @@ | ||||
| # Ultralytics YOLO 🚀, GPL-3.0 license | ||||
| 
 | ||||
| import copy | ||||
| 
 | ||||
| import cv2 | ||||
| import matplotlib.pyplot as plt | ||||
| import numpy as np | ||||
| 
 | ||||
| from ultralytics.yolo.utils import LOGGER | ||||
| 
 | ||||
| 
 | ||||
| class GMC: | ||||
| 
 | ||||
| @ -108,7 +112,7 @@ class GMC: | ||||
|         try: | ||||
|             (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1) | ||||
|         except Exception as e: | ||||
|             print(f'Warning: find transform failed. Set warp as identity {e}') | ||||
|             LOGGER.warning(f'WARNING: find transform failed. Set warp as identity {e}') | ||||
| 
 | ||||
|         return H | ||||
| 
 | ||||
| @ -229,7 +233,7 @@ class GMC: | ||||
|                 H[0, 2] *= self.downscale | ||||
|                 H[1, 2] *= self.downscale | ||||
|         else: | ||||
|             print('Warning: not enough matching points') | ||||
|             LOGGER.warning('WARNING: not enough matching points') | ||||
| 
 | ||||
|         # Store to next iteration | ||||
|         self.prevFrame = frame.copy() | ||||
| @ -288,7 +292,7 @@ class GMC: | ||||
|                 H[0, 2] *= self.downscale | ||||
|                 H[1, 2] *= self.downscale | ||||
|         else: | ||||
|             print('Warning: not enough matching points') | ||||
|             LOGGER.warning('WARNING: not enough matching points') | ||||
| 
 | ||||
|         # Store to next iteration | ||||
|         self.prevFrame = frame.copy() | ||||
|  | ||||
| @ -1,3 +1,5 @@ | ||||
| # Ultralytics YOLO 🚀, GPL-3.0 license | ||||
| 
 | ||||
| import numpy as np | ||||
| import scipy.linalg | ||||
| 
 | ||||
| @ -234,7 +236,7 @@ class KalmanFilterXYAH: | ||||
| 
 | ||||
| class KalmanFilterXYWH: | ||||
|     """ | ||||
|     For bot-sort | ||||
|     For BoT-SORT | ||||
|     A simple Kalman filter for tracking bounding boxes in image space. | ||||
| 
 | ||||
|     The 8-dimensional state space | ||||
|  | ||||
| @ -1,3 +1,5 @@ | ||||
| # Ultralytics YOLO 🚀, GPL-3.0 license | ||||
| 
 | ||||
| import lap | ||||
| import numpy as np | ||||
| import scipy | ||||
|  | ||||
| @ -1136,11 +1136,11 @@ class HUBDatasetStats(): | ||||
|         # Save, print and return | ||||
|         if save: | ||||
|             stats_path = self.hub_dir / 'stats.json' | ||||
|             print(f'Saving {stats_path.resolve()}...') | ||||
|             LOGGER.info(f'Saving {stats_path.resolve()}...') | ||||
|             with open(stats_path, 'w') as f: | ||||
|                 json.dump(self.stats, f)  # save stats.json | ||||
|         if verbose: | ||||
|             print(json.dumps(self.stats, indent=2, sort_keys=False)) | ||||
|             LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False)) | ||||
|         return self.stats | ||||
| 
 | ||||
|     def process_images(self): | ||||
| @ -1154,7 +1154,7 @@ class HUBDatasetStats(): | ||||
|             with ThreadPool(NUM_THREADS) as pool: | ||||
|                 for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=total, desc=desc): | ||||
|                     pass | ||||
|         print(f'Done. All images saved to {self.im_dir}') | ||||
|         LOGGER.info(f'Done. All images saved to {self.im_dir}') | ||||
|         return self.im_dir | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -75,7 +75,6 @@ from ultralytics.yolo.utils.files import file_size | ||||
| from ultralytics.yolo.utils.ops import Profile | ||||
| from ultralytics.yolo.utils.torch_utils import get_latest_opset, select_device, smart_inference_mode | ||||
| 
 | ||||
| CUDA = torch.cuda.is_available() | ||||
| ARM64 = platform.machine() in ('arm64', 'aarch64') | ||||
| 
 | ||||
| 
 | ||||
| @ -324,7 +323,7 @@ class Exporter: | ||||
|         # Simplify | ||||
|         if self.args.simplify: | ||||
|             try: | ||||
|                 check_requirements(('onnxsim', 'onnxruntime-gpu' if CUDA else 'onnxruntime')) | ||||
|                 check_requirements(('onnxsim', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime')) | ||||
|                 import onnxsim | ||||
| 
 | ||||
|                 LOGGER.info(f'{prefix} simplifying with onnxsim {onnxsim.__version__}...') | ||||
| @ -506,10 +505,12 @@ class Exporter: | ||||
|         try: | ||||
|             import tensorflow as tf  # noqa | ||||
|         except ImportError: | ||||
|             check_requirements(f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if CUDA else '-cpu'}") | ||||
|             check_requirements( | ||||
|                 f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if torch.cuda.is_available() else '-cpu'}" | ||||
|             ) | ||||
|             import tensorflow as tf  # noqa | ||||
|         check_requirements(('onnx', 'onnx2tf', 'sng4onnx', 'onnxsim', 'onnx_graphsurgeon', 'tflite_support', | ||||
|                             'onnxruntime-gpu' if CUDA else 'onnxruntime'), | ||||
|                             'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'), | ||||
|                            cmds='--extra-index-url https://pypi.ngc.nvidia.com') | ||||
| 
 | ||||
|         LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') | ||||
|  | ||||
| @ -32,7 +32,7 @@ class YOLO: | ||||
|         YOLO (You Only Look Once) object detection model. | ||||
| 
 | ||||
|         Args: | ||||
|             model (str or Path): Path to the model file to load or create. | ||||
|             model (str, Path): Path to the model file to load or create. | ||||
|             type (str): Type/version of models to use. Defaults to "v8". | ||||
| 
 | ||||
|         Attributes: | ||||
| @ -62,7 +62,7 @@ class YOLO: | ||||
|             predict(source=None, stream=False, **kwargs): Perform prediction using the YOLO model. | ||||
| 
 | ||||
|         Returns: | ||||
|             List[ultralytics.yolo.engine.results.Results]: The prediction results. | ||||
|             list(ultralytics.yolo.engine.results.Results): The prediction results. | ||||
|         """ | ||||
| 
 | ||||
|     def __init__(self, model='yolov8n.pt', type='v8') -> None: | ||||
| @ -114,6 +114,7 @@ class YOLO: | ||||
|         self.task = guess_model_task(cfg_dict) | ||||
|         self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task() | ||||
|         self.model = self.ModelClass(cfg_dict, verbose=verbose and RANK == -1)  # initialize | ||||
|         self.overrides['model'] = self.cfg | ||||
| 
 | ||||
|     def _load(self, weights: str): | ||||
|         """ | ||||
| @ -204,7 +205,7 @@ class YOLO: | ||||
|     def track(self, source=None, stream=False, **kwargs): | ||||
|         from ultralytics.tracker.track import register_tracker | ||||
|         register_tracker(self) | ||||
|         # bytetrack-based method needs low confidence predictions as input | ||||
|         # ByteTrack-based method needs low confidence predictions as input | ||||
|         conf = kwargs.get('conf') or 0.1 | ||||
|         kwargs['conf'] = conf | ||||
|         kwargs['mode'] = 'track' | ||||
|  | ||||
| @ -92,6 +92,7 @@ class BasePredictor: | ||||
|         self.annotator = None | ||||
|         self.data_path = None | ||||
|         self.source_type = None | ||||
|         self.batch = None | ||||
|         self.callbacks = defaultdict(list, callbacks.default_callbacks)  # add callbacks | ||||
|         callbacks.add_integration_callbacks(self) | ||||
| 
 | ||||
|  | ||||
| @ -28,13 +28,14 @@ class Results: | ||||
| 
 | ||||
|         """ | ||||
| 
 | ||||
|     def __init__(self, boxes=None, masks=None, probs=None, orig_img=None, names=None) -> None: | ||||
|     def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None) -> None: | ||||
|         self.orig_img = orig_img | ||||
|         self.orig_shape = orig_img.shape[:2] | ||||
|         self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None  # native size boxes | ||||
|         self.masks = Masks(masks, self.orig_shape) if masks is not None else None  # native size or imgsz masks | ||||
|         self.probs = probs if probs is not None else None | ||||
|         self.names = names | ||||
|         self.path = path | ||||
|         self.comp = ['boxes', 'masks', 'probs'] | ||||
| 
 | ||||
|     def pandas(self): | ||||
| @ -42,7 +43,7 @@ class Results: | ||||
|         # TODO masks.pandas + boxes.pandas + cls.pandas | ||||
| 
 | ||||
|     def __getitem__(self, idx): | ||||
|         r = Results(orig_img=self.orig_img) | ||||
|         r = Results(orig_img=self.orig_img, path=self.path, names=self.names) | ||||
|         for item in self.comp: | ||||
|             if getattr(self, item) is None: | ||||
|                 continue | ||||
| @ -58,7 +59,7 @@ class Results: | ||||
|             self.probs = probs | ||||
| 
 | ||||
|     def cpu(self): | ||||
|         r = Results(orig_img=self.orig_img) | ||||
|         r = Results(orig_img=self.orig_img, path=self.path, names=self.names) | ||||
|         for item in self.comp: | ||||
|             if getattr(self, item) is None: | ||||
|                 continue | ||||
| @ -66,7 +67,7 @@ class Results: | ||||
|         return r | ||||
| 
 | ||||
|     def numpy(self): | ||||
|         r = Results(orig_img=self.orig_img) | ||||
|         r = Results(orig_img=self.orig_img, path=self.path, names=self.names) | ||||
|         for item in self.comp: | ||||
|             if getattr(self, item) is None: | ||||
|                 continue | ||||
| @ -74,7 +75,7 @@ class Results: | ||||
|         return r | ||||
| 
 | ||||
|     def cuda(self): | ||||
|         r = Results(orig_img=self.orig_img) | ||||
|         r = Results(orig_img=self.orig_img, path=self.path, names=self.names) | ||||
|         for item in self.comp: | ||||
|             if getattr(self, item) is None: | ||||
|                 continue | ||||
| @ -82,7 +83,7 @@ class Results: | ||||
|         return r | ||||
| 
 | ||||
|     def to(self, *args, **kwargs): | ||||
|         r = Results(orig_img=self.orig_img) | ||||
|         r = Results(orig_img=self.orig_img, path=self.path, names=self.names) | ||||
|         for item in self.comp: | ||||
|             if getattr(self, item) is None: | ||||
|                 continue | ||||
| @ -123,7 +124,7 @@ class Results: | ||||
|                 orig_shape (tuple, optional): Original image size. | ||||
|             """) | ||||
| 
 | ||||
|     def visualize(self, show_conf=True, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'): | ||||
|     def plot(self, show_conf=True, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'): | ||||
|         """ | ||||
|         Plots the given result on an input RGB image. Accepts cv2(numpy) or PIL Image | ||||
| 
 | ||||
| @ -146,9 +147,9 @@ class Results: | ||||
|                 annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) | ||||
| 
 | ||||
|         if masks is not None: | ||||
|             im_gpu = torch.as_tensor(img, dtype=torch.float16).permute(2, 0, 1).flip(0).contiguous() | ||||
|             im_gpu = F.resize(im_gpu, masks.data.shape[1:]) / 255 | ||||
|             annotator.masks(masks.data, colors=[colors(x, True) for x in boxes.cls], im_gpu=im_gpu) | ||||
|             im = torch.as_tensor(img, dtype=torch.float16, device=masks.data.device).permute(2, 0, 1).flip(0) | ||||
|             im = F.resize(im.contiguous(), masks.data.shape[1:]) / 255 | ||||
|             annotator.masks(masks.data, colors=[colors(x, True) for x in boxes.cls], im_gpu=im) | ||||
| 
 | ||||
|         if logits is not None: | ||||
|             top5i = logits.argsort(0, descending=True)[:5].tolist()  # top 5 indices | ||||
| @ -371,24 +372,3 @@ class Masks: | ||||
|             Properties: | ||||
|                 segments (list): A list of segments which includes x,y,w,h,label,confidence, and mask of each detection masks. | ||||
|             """) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|     # test examples | ||||
|     results = Results(boxes=torch.randn((2, 6)), masks=torch.randn((2, 160, 160)), orig_shape=[640, 640]) | ||||
|     results = results.cuda() | ||||
|     print('--cuda--pass--') | ||||
|     results = results.cpu() | ||||
|     print('--cpu--pass--') | ||||
|     results = results.to('cuda:0') | ||||
|     print('--to-cuda--pass--') | ||||
|     results = results.to('cpu') | ||||
|     print('--to-cpu--pass--') | ||||
|     results = results.numpy() | ||||
|     print('--numpy--pass--') | ||||
|     # box = Boxes(boxes=torch.randn((2, 6)), orig_shape=[5, 5]) | ||||
|     # box = box.cuda() | ||||
|     # box = box.cpu() | ||||
|     # box = box.numpy() | ||||
|     # for b in box: | ||||
|     #     print(b) | ||||
|  | ||||
| @ -11,7 +11,7 @@ import numpy as np | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| 
 | ||||
| from ultralytics.yolo.utils import TryExcept | ||||
| from ultralytics.yolo.utils import LOGGER, TryExcept | ||||
| 
 | ||||
| 
 | ||||
| # boxes | ||||
| @ -260,7 +260,7 @@ class ConfusionMatrix: | ||||
| 
 | ||||
|     def print(self): | ||||
|         for i in range(self.nc + 1): | ||||
|             print(' '.join(map(str, self.matrix[i]))) | ||||
|             LOGGER.info(' '.join(map(str, self.matrix[i]))) | ||||
| 
 | ||||
| 
 | ||||
| def smooth(y, f=0.05): | ||||
|  | ||||
| @ -12,7 +12,7 @@ import torch | ||||
| from PIL import Image, ImageDraw, ImageFont | ||||
| from PIL import __version__ as pil_version | ||||
| 
 | ||||
| from ultralytics.yolo.utils import threaded | ||||
| from ultralytics.yolo.utils import LOGGER, threaded | ||||
| 
 | ||||
| from .checks import check_font, check_version, is_ascii | ||||
| from .files import increment_path | ||||
| @ -300,7 +300,7 @@ def plot_results(file='path/to/results.csv', dir='', segment=False): | ||||
|                 # if j in [8, 9, 10]:  # share train and val loss y axes | ||||
|                 #     ax[i].get_shared_y_axes().join(ax[i], ax[i - 5]) | ||||
|         except Exception as e: | ||||
|             print(f'Warning: Plotting error for {f}: {e}') | ||||
|             LOGGER.warning(f'WARNING: Plotting error for {f}: {e}') | ||||
|     ax[1].legend() | ||||
|     fig.savefig(save_dir / 'results.png', dpi=200) | ||||
|     plt.close() | ||||
|  | ||||
| @ -167,11 +167,12 @@ def model_info(model, verbose=False, imgsz=640): | ||||
|     n_p = get_num_params(model) | ||||
|     n_g = get_num_gradients(model)  # number gradients | ||||
|     if verbose: | ||||
|         print(f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}") | ||||
|         LOGGER.info( | ||||
|             f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}") | ||||
|         for i, (name, p) in enumerate(model.named_parameters()): | ||||
|             name = name.replace('module_list.', '') | ||||
|             print('%5g %40s %9s %12g %20s %10.3g %10.3g' % | ||||
|                   (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std())) | ||||
|             LOGGER.info('%5g %40s %9s %12g %20s %10.3g %10.3g' % | ||||
|                         (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std())) | ||||
| 
 | ||||
|     flops = get_flops(model, imgsz) | ||||
|     fused = ' (fused)' if model.is_fused() else '' | ||||
| @ -362,8 +363,8 @@ def profile(input, ops, n=10, device=None): | ||||
|     results = [] | ||||
|     if not isinstance(device, torch.device): | ||||
|         device = select_device(device) | ||||
|     print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}" | ||||
|           f"{'input':>24s}{'output':>24s}") | ||||
|     LOGGER.info(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}" | ||||
|                 f"{'input':>24s}{'output':>24s}") | ||||
| 
 | ||||
|     for x in input if isinstance(input, list) else [input]: | ||||
|         x = x.to(device) | ||||
| @ -393,10 +394,10 @@ def profile(input, ops, n=10, device=None): | ||||
|                 mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0  # (GB) | ||||
|                 s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' for x in (x, y))  # shapes | ||||
|                 p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0  # parameters | ||||
|                 print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}') | ||||
|                 LOGGER.info(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}') | ||||
|                 results.append([p, flops, mem, tf, tb, s_in, s_out]) | ||||
|             except Exception as e: | ||||
|                 print(e) | ||||
|                 LOGGER.info(e) | ||||
|                 results.append(None) | ||||
|             torch.cuda.empty_cache() | ||||
|     return results | ||||
|  | ||||
| @ -22,7 +22,9 @@ class ClassificationPredictor(BasePredictor): | ||||
|         results = [] | ||||
|         for i, pred in enumerate(preds): | ||||
|             orig_img = orig_img[i] if isinstance(orig_img, list) else orig_img | ||||
|             results.append(Results(probs=pred, orig_img=orig_img, names=self.model.names)) | ||||
|             path, _, _, _, _ = self.batch | ||||
|             img_path = path[i] if isinstance(path, list) else path | ||||
|             results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, probs=pred)) | ||||
| 
 | ||||
|         return results | ||||
| 
 | ||||
|  | ||||
| @ -32,7 +32,9 @@ class DetectionPredictor(BasePredictor): | ||||
|             orig_img = orig_img[i] if isinstance(orig_img, list) else orig_img | ||||
|             shape = orig_img.shape | ||||
|             pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round() | ||||
|             results.append(Results(boxes=pred, orig_img=orig_img, names=self.model.names)) | ||||
|             path, _, _, _, _ = self.batch | ||||
|             img_path = path[i] if isinstance(path, list) else path | ||||
|             results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred)) | ||||
|         return results | ||||
| 
 | ||||
|     def write_results(self, idx, results, batch): | ||||
|  | ||||
| @ -24,9 +24,10 @@ class SegmentationPredictor(DetectionPredictor): | ||||
|         for i, pred in enumerate(p): | ||||
|             orig_img = orig_img[i] if isinstance(orig_img, list) else orig_img | ||||
|             shape = orig_img.shape | ||||
|             if not len(pred): | ||||
|                 results.append(Results(boxes=pred[:, :6], orig_img=orig_img, | ||||
|                                        names=self.model.names))  # save empty boxes | ||||
|             path, _, _, _, _ = self.batch | ||||
|             img_path = path[i] if isinstance(path, list) else path | ||||
|             if not len(pred):  # save empty boxes | ||||
|                 results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])) | ||||
|                 continue | ||||
|             if self.args.retina_masks: | ||||
|                 pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round() | ||||
| @ -34,7 +35,8 @@ class SegmentationPredictor(DetectionPredictor): | ||||
|             else: | ||||
|                 masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC | ||||
|                 pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round() | ||||
|             results.append(Results(boxes=pred[:, :6], masks=masks, orig_img=orig_img, names=self.model.names)) | ||||
|             results.append( | ||||
|                 Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)) | ||||
|         return results | ||||
| 
 | ||||
|     def write_results(self, idx, results, batch): | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Glenn Jocher
						Glenn Jocher