ultralytics 8.0.71 updates and fixes (#1907)

Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Pavel Bugneac <50273042+pavelbugneac@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent c38b17a0d8 · commit 4e997013bc
.github/ISSUE_TEMPLATE/config.yml (vendored)
@@ -6,6 +6,6 @@ contact_links:
   - name: 💬 Forum
     url: https://community.ultralytics.com/
     about: Ask on Ultralytics Community Forum
-  - name: Stack Overflow
-    url: https://stackoverflow.com/search?q=YOLOv8
-    about: Ask on Stack Overflow with 'YOLOv8' tag
+  - name: 🎧 Discord
+    url: https://discord.gg/n6cFeSPZdD
+    about: Ask on Ultralytics Discord

@@ -23,7 +23,7 @@ full list of export arguments.
         ```python
         from ultralytics.yolo.utils.benchmarks import benchmark

-        # Benchmark
+        # Benchmark on GPU
         benchmark(model='yolov8n.pt', imgsz=640, half=False, device=0)
         ```
     === "CLI"

@@ -63,3 +63,5 @@ Benchmarks will attempt to run automatically on all possible export formats below
 | [TF Edge TPU](https://coral.ai/docs/edgetpu/models-intro/) | `edgetpu` | `yolov8n_edgetpu.tflite` | ✅ |
 | [TF.js](https://www.tensorflow.org/js) | `tfjs` | `yolov8n_web_model/` | ✅ |
 | [PaddlePaddle](https://github.com/PaddlePaddle) | `paddle` | `yolov8n_paddle_model/` | ✅ |
+
+See full `export` details in the [Export](https://docs.ultralytics.com/modes/export/) page.

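As a usage note on the `benchmark` helper documented above: a minimal sketch of the same call on CPU, assuming `device` also accepts the string `'cpu'` as in the other YOLOv8 modes.

```python
from ultralytics.yolo.utils.benchmarks import benchmark

# Benchmark on CPU (assumption: device='cpu' is accepted, mirroring predict/val/export)
benchmark(model='yolov8n.pt', imgsz=640, half=False, device='cpu')
```
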
@@ -90,7 +90,6 @@ task.
 | `dfl`             | `1.5`  | dfl loss gain                                              |
 | `pose`            | `12.0` | pose loss gain (pose-only)                                 |
 | `kobj`            | `2.0`  | keypoint obj loss gain (pose-only)                         |
-| `fl_gamma`        | `0.0`  | focal loss gamma (efficientDet default gamma=1.5)          |
 | `label_smoothing` | `0.0`  | label smoothing (fraction)                                 |
 | `nbs`             | `64`   | nominal batch size                                         |
 | `overlap_mask`    | `True` | masks should overlap during training (segment train only)  |

@@ -112,7 +112,6 @@ The training settings for YOLO models encompass various hyperparameters and conf
 | `dfl`             | `1.5`  | dfl loss gain                                              |
 | `pose`            | `12.0` | pose loss gain (pose-only)                                 |
 | `kobj`            | `2.0`  | keypoint obj loss gain (pose-only)                         |
-| `fl_gamma`        | `0.0`  | focal loss gamma (efficientDet default gamma=1.5)          |
 | `label_smoothing` | `0.0`  | label smoothing (fraction)                                 |
 | `nbs`             | `64`   | nominal batch size                                         |
 | `overlap_mask`    | `True` | masks should overlap during training (segment train only)  |

@@ -296,7 +296,7 @@
      "name": "stdout",
      "text": [
       "Ultralytics YOLOv8.0.57 🚀 Python-3.9.16 torch-1.13.1+cu116 CUDA:0 (Tesla T4, 15102MiB)\n",
-      "\u001b[34m\u001b[1myolo/engine/trainer: \u001b[0mtask=detect, mode=train, model=yolov8n.pt, data=coco128.yaml, epochs=3, patience=50, batch=16, imgsz=640, save=True, save_period=-1, cache=False, device=None, workers=8, project=None, name=None, exist_ok=False, pretrained=False, optimizer=SGD, verbose=True, seed=0, deterministic=True, single_cls=False, image_weights=False, rect=False, cos_lr=False, close_mosaic=10, resume=False, amp=True, overlap_mask=True, mask_ratio=4, dropout=0.0, val=True, split=val, save_json=False, save_hybrid=False, conf=None, iou=0.7, max_det=300, half=False, dnn=False, plots=True, source=None, show=False, save_txt=False, save_conf=False, save_crop=False, hide_labels=False, hide_conf=False, vid_stride=1, line_thickness=3, visualize=False, augment=False, agnostic_nms=False, classes=None, retina_masks=False, boxes=True, format=torchscript, keras=False, optimize=False, int8=False, dynamic=False, simplify=False, opset=None, workspace=4, nms=False, lr0=0.01, lrf=0.01, momentum=0.937, weight_decay=0.0005, warmup_epochs=3.0, warmup_momentum=0.8, warmup_bias_lr=0.1, box=7.5, cls=0.5, dfl=1.5, fl_gamma=0.0, label_smoothing=0.0, nbs=64, hsv_h=0.015, hsv_s=0.7, hsv_v=0.4, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, flipud=0.0, fliplr=0.5, mosaic=1.0, mixup=0.0, copy_paste=0.0, cfg=None, v5loader=False, tracker=botsort.yaml, save_dir=runs/detect/train\n",
+      "\u001b[34m\u001b[1myolo/engine/trainer: \u001b[0mtask=detect, mode=train, model=yolov8n.pt, data=coco128.yaml, epochs=3, patience=50, batch=16, imgsz=640, save=True, save_period=-1, cache=False, device=None, workers=8, project=None, name=None, exist_ok=False, pretrained=False, optimizer=SGD, verbose=True, seed=0, deterministic=True, single_cls=False, image_weights=False, rect=False, cos_lr=False, close_mosaic=10, resume=False, amp=True, overlap_mask=True, mask_ratio=4, dropout=0.0, val=True, split=val, save_json=False, save_hybrid=False, conf=None, iou=0.7, max_det=300, half=False, dnn=False, plots=True, source=None, show=False, save_txt=False, save_conf=False, save_crop=False, hide_labels=False, hide_conf=False, vid_stride=1, line_thickness=3, visualize=False, augment=False, agnostic_nms=False, classes=None, retina_masks=False, boxes=True, format=torchscript, keras=False, optimize=False, int8=False, dynamic=False, simplify=False, opset=None, workspace=4, nms=False, lr0=0.01, lrf=0.01, momentum=0.937, weight_decay=0.0005, warmup_epochs=3.0, warmup_momentum=0.8, warmup_bias_lr=0.1, box=7.5, cls=0.5, dfl=1.5, label_smoothing=0.0, nbs=64, hsv_h=0.015, hsv_s=0.7, hsv_v=0.4, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, flipud=0.0, fliplr=0.5, mosaic=1.0, mixup=0.0, copy_paste=0.0, cfg=None, v5loader=False, tracker=botsort.yaml, save_dir=runs/detect/train\n",
       "\n",
       "                   from  n    params  module                                       arguments                     \n",
       "  0                  -1  1       464  ultralytics.nn.modules.Conv                 [3, 16, 3, 2]                 \n",

@@ -205,18 +205,33 @@ def test_predict_callback_and_setup():


 def test_result():
+    model = YOLO('yolov8n-pose.pt')
+    res = model([SOURCE, SOURCE])
+    res[0].plot(show_conf=False)  # raises warning
+    res[0].plot(conf=True, boxes=False)
+    res[0].plot(pil=True)
+    res[0] = res[0].cpu().numpy()
+    print(res[0].path, res[0].keypoints)
+
     model = YOLO('yolov8n-seg.pt')
     res = model([SOURCE, SOURCE])
     res[0].plot(show_conf=False)  # raises warning
     res[0].plot(conf=True, boxes=False, masks=True)
+    res[0].plot(pil=True)
     res[0] = res[0].cpu().numpy()
     print(res[0].path, res[0].masks.masks)

     model = YOLO('yolov8n.pt')
     res = model(SOURCE)
+    res[0].plot(pil=True)
     res[0].plot()
+    res[0] = res[0].cpu().numpy()
     print(res[0].path)

     model = YOLO('yolov8n-cls.pt')
     res = model(SOURCE)
     res[0].plot(probs=False)
+    res[0].plot(pil=True)
+    res[0].plot()
+    res[0] = res[0].cpu().numpy()
     print(res[0].path)

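The extended `test_result` above exercises the `Results` plotting API for every task head. Below is a stand-alone sketch of the same calls outside the test harness; `'bus.jpg'` is a placeholder source, and `plot(pil=True)` is assumed to return the annotated image rather than only drawing in place.

```python
from ultralytics import YOLO

model = YOLO('yolov8n-pose.pt')     # pose weights, as in the test above
res = model('bus.jpg')              # 'bus.jpg' is a placeholder image path
annotated = res[0].plot(pil=True)   # assumed to return the annotated image
res[0] = res[0].cpu().numpy()       # move results to CPU / numpy, as the test does
print(res[0].path, res[0].keypoints)
```
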
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license

-__version__ = '8.0.70'
+__version__ = '8.0.71'

 from ultralytics.hub import start
 from ultralytics.yolo.engine.model import YOLO

@@ -1,5 +1,7 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license

+from functools import partial
+
 import torch

 from ultralytics.yolo.utils import IterableSimpleNamespace, yaml_load

@@ -10,7 +12,19 @@ from .trackers import BOTSORT, BYTETracker
 TRACKER_MAP = {'bytetrack': BYTETracker, 'botsort': BOTSORT}


-def on_predict_start(predictor):
+def on_predict_start(predictor, persist=False):
+    """
+    Initialize trackers for object tracking during prediction.
+
+    Args:
+        predictor (object): The predictor object to initialize trackers for.
+        persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
+
+    Raises:
+        AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
+    """
+    if hasattr(predictor, 'trackers') and persist:
+        return
     tracker = check_yaml(predictor.args.tracker)
     cfg = IterableSimpleNamespace(**yaml_load(tracker))
     assert cfg.tracker_type in ['bytetrack', 'botsort'], \

@@ -38,6 +52,14 @@ def on_predict_postprocess_end(predictor):
         predictor.results[i].update(boxes=torch.as_tensor(tracks[:, :-1]))


-def register_tracker(model):
-    model.add_callback('on_predict_start', on_predict_start)
+def register_tracker(model, persist):
+    """
+    Register tracking callbacks to the model for object tracking during prediction.
+
+    Args:
+        model (object): The model object to register tracking callbacks for.
+        persist (bool): Whether to persist the trackers if they already exist.
+
+    """
+    model.add_callback('on_predict_start', partial(on_predict_start, persist=persist))
     model.add_callback('on_predict_postprocess_end', on_predict_postprocess_end)

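The `functools.partial` wrapper is what lets the extra `persist` flag reach `on_predict_start` while the callback system still invokes it with a single predictor argument. A minimal sketch of the pattern with illustrative names (not the library's own classes):

```python
from functools import partial


def on_start(ctx, persist=False):
    # Re-use existing tracker state when persist=True, otherwise re-initialize it
    if getattr(ctx, 'trackers', None) is not None and persist:
        return
    ctx.trackers = ['fresh tracker state']


class Ctx:  # stand-in for the predictor object
    pass


ctx = Ctx()
callbacks = [partial(on_start, persist=True)]  # bind the extra argument up front

for cb in callbacks:
    cb(ctx)  # the callback runner only ever passes the predictor-like object
print(ctx.trackers)
```
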
@@ -277,12 +277,13 @@ class BYTETracker:
         self.lost_stracks = self.sub_stracks(self.lost_stracks, self.tracked_stracks)
         self.lost_stracks.extend(lost_stracks)
         self.lost_stracks = self.sub_stracks(self.lost_stracks, self.removed_stracks)
-        self.removed_stracks.extend(removed_stracks)
         self.tracked_stracks, self.lost_stracks = self.remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
-        output = [
-            track.tlbr.tolist() + [track.track_id, track.score, track.cls, track.idx] for track in self.tracked_stracks
-            if track.is_activated]
-        return np.asarray(output, dtype=np.float32)
+        self.removed_stracks.extend(removed_stracks)
+        if len(self.removed_stracks) > 1000:
+            self.removed_stracks = self.removed_stracks[-999:]  # clip remove stracks to 1000 maximum
+        return np.asarray(
+            [x.tlbr.tolist() + [x.track_id, x.score, x.cls, x.idx] for x in self.tracked_stracks if x.is_activated],
+            dtype=np.float32)

     def get_kalmanfilter(self):
         return KalmanFilterXYAH()

@@ -319,12 +320,16 @@ class BYTETracker:

     @staticmethod
     def sub_stracks(tlista, tlistb):
+        """ DEPRECATED CODE in https://github.com/ultralytics/ultralytics/pull/1890/
         stracks = {t.track_id: t for t in tlista}
         for t in tlistb:
             tid = t.track_id
             if stracks.get(tid, 0):
                 del stracks[tid]
         return list(stracks.values())
+        """
+        track_ids_b = {t.track_id for t in tlistb}
+        return [t for t in tlista if t.track_id not in track_ids_b]

     @staticmethod
     def remove_duplicate_stracks(stracksa, stracksb):

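The rewritten `sub_stracks` keeps the old dict-based loop only as a deprecated comment and filters by a set of track IDs instead. A self-contained sketch of the same idea; `Track` here is a stand-in for the tracker's own track objects:

```python
from dataclasses import dataclass


@dataclass
class Track:  # stand-in for the tracker's STrack-style objects
    track_id: int


def sub_stracks(tlista, tlistb):
    # Keep tracks from tlista whose track_id does not appear in tlistb
    track_ids_b = {t.track_id for t in tlistb}
    return [t for t in tlista if t.track_id not in track_ids_b]


a = [Track(1), Track(2), Track(3)]
b = [Track(2)]
print([t.track_id for t in sub_stracks(a, b)])  # -> [1, 3]
```
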
@@ -63,7 +63,7 @@ CLI_HELP_MSG = \
     """

 # Define keys for arg type checks
-CFG_FLOAT_KEYS = 'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear', 'fl_gamma'
+CFG_FLOAT_KEYS = 'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear'
 CFG_FRACTION_KEYS = ('dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr',
                      'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud',
                      'fliplr', 'mosaic', 'mixup', 'copy_paste', 'conf', 'iou')  # fractional floats limited to 0.0 - 1.0

@@ -90,7 +90,6 @@ cls: 0.5  # cls loss gain (scale with pixels)
 dfl: 1.5  # dfl loss gain
 pose: 12.0  # pose loss gain
 kobj: 1.0  # keypoint obj loss gain
-fl_gamma: 0.0  # focal loss gamma (efficientDet default gamma=1.5)
 label_smoothing: 0.0  # label smoothing (fraction)
 nbs: 64  # nominal batch size
 hsv_h: 0.015  # image HSV-Hue augmentation (fraction)

@@ -93,7 +93,8 @@ def build_dataloader(cfg, batch, img_path, data_info, stride=32, rect=False, ran
     loader = DataLoader if cfg.image_weights or cfg.close_mosaic else InfiniteDataLoader  # allow attribute updates
     generator = torch.Generator()
     generator.manual_seed(6148914691236517205 + RANK)
-    return loader(dataset=dataset,
+    return loader(
+        dataset=dataset,
                   batch_size=batch,
                   shuffle=shuffle and sampler is None,
                   num_workers=nw,

@@ -101,6 +102,7 @@ def build_dataloader(cfg, batch, img_path, data_info, stride=32, rect=False, ran
                   pin_memory=PIN_MEMORY,
                   collate_fn=getattr(dataset, 'collate_fn', None),
                   worker_init_fn=seed_worker,
+                  persistent_workers=(nw > 0) and (loader == DataLoader),  # persist workers if using default PyTorch DataLoader
                   generator=generator), dataset


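The added `persistent_workers` flag keeps worker processes alive between epochs when the plain PyTorch `DataLoader` is selected, instead of respawning them each epoch. A minimal sketch of the idea with a toy dataset (the fixed generator seed mirrors the one above):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

if __name__ == '__main__':  # guard needed when num_workers > 0 on spawn-based platforms
    dataset = TensorDataset(torch.randn(64, 3, 32, 32), torch.zeros(64, dtype=torch.long))
    nw = 2  # number of dataloader workers

    generator = torch.Generator()
    generator.manual_seed(6148914691236517205)  # fixed base seed, as in build_dataloader above

    loader = DataLoader(dataset,
                        batch_size=16,
                        shuffle=True,
                        num_workers=nw,
                        persistent_workers=nw > 0,  # keep workers alive across epochs
                        generator=generator)

    for epoch in range(2):
        for imgs, labels in loader:
            pass  # a training step would go here
```
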
@@ -37,7 +37,7 @@ class YOLODataset(BaseDataset):
         single_cls (bool): if True, single class training is used (default: False).
         use_segments (bool): if True, segmentation masks are used as labels (default: False).
         use_keypoints (bool): if True, keypoints are used as labels (default: False).
-        names (list): class names (default: None).
+        names (dict): A dictionary of class names. (default: None).

     Returns:
         A PyTorch dataset object that can be used for training an object detection or segmentation model.

@@ -138,7 +138,7 @@ class Exporter:
             overrides (dict, optional): Configuration overrides. Defaults to None.
         """
         self.args = get_cfg(cfg, overrides)
-        self.callbacks = _callbacks if _callbacks else callbacks.get_default_callbacks()
+        self.callbacks = _callbacks or callbacks.get_default_callbacks()
         callbacks.add_integration_callbacks(self)

     @smart_inference_mode()

@@ -379,6 +379,7 @@ class Exporter:
         yaml_save(Path(f) / 'metadata.yaml', self.metadata)  # add metadata.yaml
         return f, None

+    @try_export
     def _export_coreml(self, prefix=colorstr('CoreML:')):
         # YOLOv8 CoreML export
         check_requirements('coremltools>=6.0')

@@ -235,7 +235,8 @@ class YOLO:
         overrides.update(kwargs)  # prefer kwargs
         overrides['mode'] = kwargs.get('mode', 'predict')
         assert overrides['mode'] in ['track', 'predict']
-        overrides['save'] = kwargs.get('save', False)  # not save files by default
+        if not is_cli:
+            overrides['save'] = kwargs.get('save', False)  # do not save by default if called in Python
         if not self.predictor:
             self.task = overrides.get('task') or self.task
             self.predictor = TASK_MAP[self.task][3](overrides=overrides, _callbacks=self.callbacks)

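With the `is_cli` guard above, Python calls no longer write prediction outputs unless `save=True` is passed explicitly, while the CLI keeps its own saving default. A short usage sketch under that assumption; `'bus.jpg'` is a placeholder source:

```python
from ultralytics import YOLO

model = YOLO('yolov8n.pt')
results = model.predict('bus.jpg')             # Python API: nothing saved to disk by default
results = model.predict('bus.jpg', save=True)  # opt in to saving the annotated output
```
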
@@ -244,10 +245,23 @@ class YOLO:
             self.predictor.args = get_cfg(self.predictor.args, overrides)
         return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)

-    def track(self, source=None, stream=False, **kwargs):
+    def track(self, source=None, stream=False, persist=False, **kwargs):
+        """
+        Perform object tracking on the input source using the registered trackers.
+
+        Args:
+            source (str, optional): The input source for object tracking. Can be a file path or a video stream.
+            stream (bool, optional): Whether the input source is a video stream. Defaults to False.
+            persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
+            **kwargs: Additional keyword arguments for the tracking process.
+
+        Returns:
+            object: The tracking results.
+
+        """
         if not hasattr(self.predictor, 'trackers'):
             from ultralytics.tracker import register_tracker
-            register_tracker(self)
+            register_tracker(self, persist)
         # ByteTrack-based method needs low confidence predictions as input
         conf = kwargs.get('conf') or 0.1
         kwargs['conf'] = conf

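The new `persist` argument is aimed at frame-by-frame use: when `track()` is called repeatedly on single frames, `persist=True` keeps the trackers created on the first call instead of re-initializing them. A hedged sketch of that loop; the OpenCV capture and video path are illustrative:

```python
import cv2
from ultralytics import YOLO

model = YOLO('yolov8n.pt')
cap = cv2.VideoCapture('video.mp4')  # placeholder video path

while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    # persist=True re-uses the existing tracker state between calls
    results = model.track(frame, persist=True)
    print(results[0].boxes)

cap.release()
```
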
@@ -103,7 +103,7 @@ class BasePredictor:
         self.data_path = None
         self.source_type = None
         self.batch = None
-        self.callbacks = _callbacks if _callbacks else callbacks.get_default_callbacks()
+        self.callbacks = _callbacks or callbacks.get_default_callbacks()
         callbacks.add_integration_callbacks(self)

     def preprocess(self, img):

@@ -70,10 +70,12 @@ class Results(SimpleClass):
     Args:
         orig_img (numpy.ndarray): The original image as a numpy array.
         path (str): The path to the image file.
-        names (List[str]): A list of class names.
+        names (dict): A dictionary of class names.
         boxes (List[List[float]], optional): A list of bounding box coordinates for each detection.
         masks (numpy.ndarray, optional): A 3D numpy array of detection masks, where each mask is a binary image.
         probs (numpy.ndarray, optional): A 2D numpy array of detection probabilities for each class.
+        keypoints (List[List[float]], optional): A list of detected keypoints for each object.
+

     Attributes:
         orig_img (numpy.ndarray): The original image as a numpy array.

@@ -81,9 +83,12 @@ class Results(SimpleClass):
         boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
         masks (Masks, optional): A Masks object containing the detection masks.
         probs (numpy.ndarray, optional): A 2D numpy array of detection probabilities for each class.
-        names (List[str]): A list of class names.
+        names (dict): A dictionary of class names.
         path (str): The path to the image file.
+        keypoints (List[List[float]], optional): A list of detected keypoints for each object.
+        speed (dict): A dictionary of preprocess, inference and postprocess speeds in milliseconds per image.
         _keys (tuple): A tuple of attribute names for non-empty attributes.
     """

     def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None, keypoints=None) -> None:

|
|||||||
self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks
|
self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks
|
||||||
self.probs = probs if probs is not None else None
|
self.probs = probs if probs is not None else None
|
||||||
self.keypoints = keypoints if keypoints is not None else None
|
self.keypoints = keypoints if keypoints is not None else None
|
||||||
|
self.speed = {'preprocess': None, 'inference': None, 'postprocess': None} # milliseconds per image
|
||||||
self.names = names
|
self.names = names
|
||||||
self.path = path
|
self.path = path
|
||||||
self._keys = ('boxes', 'masks', 'probs', 'keypoints')
|
self._keys = ('boxes', 'masks', 'probs', 'keypoints')
|
||||||
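`Results.speed` is initialized to `None` values here; the predictor is expected to fill in per-image timings. A small sketch of reading it after a prediction, assuming the three keys get populated in milliseconds:

```python
from ultralytics import YOLO

model = YOLO('yolov8n.pt')
results = model('bus.jpg')  # 'bus.jpg' is a placeholder source

speed = results[0].speed    # {'preprocess': ..., 'inference': ..., 'postprocess': ...} in ms per image
total_ms = sum(v for v in speed.values() if v is not None)
print(speed, f'{total_ms:.1f} ms total')
```
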
@@ -203,7 +209,7 @@ class Results(SimpleClass):
         keypoints = self.keypoints
         if pred_masks and show_masks:
             if img_gpu is None:
-                img = LetterBox(pred_masks.shape[1:])(image=annotator.im)
+                img = LetterBox(pred_masks.shape[1:])(image=annotator.result())
                 img_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.masks.device).permute(
                     2, 0, 1).flip(0).contiguous() / 255
                 annotator.masks(pred_masks.data, colors=[colors(x, True) for x in pred_boxes.cls], im_gpu=img_gpu)

@@ -142,7 +142,7 @@ class BaseTrainer:
         self.plot_idx = [0, 1, 2]

         # Callbacks
-        self.callbacks = _callbacks if _callbacks else callbacks.get_default_callbacks()
+        self.callbacks = _callbacks or callbacks.get_default_callbacks()
         if RANK in (-1, 0):
             callbacks.add_integration_callbacks(self)

@@ -84,7 +84,7 @@ class BaseValidator:
         if self.args.conf is None:
             self.args.conf = 0.001  # default conf=0.001

-        self.callbacks = _callbacks if _callbacks else callbacks.get_default_callbacks()
+        self.callbacks = _callbacks or callbacks.get_default_callbacks()

     @smart_inference_mode()
     def __call__(self, trainer=None, model=None):