mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
ultralytics 8.0.21
Windows, segments, YAML fixes (#655)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: corey-nm <109536191+corey-nm@users.noreply.github.com>
This commit is contained in:
parent
dc9502c700
commit
6c44ce21d9
4
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
4
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@ -51,9 +51,9 @@ body:
|
|||||||
label: Environment
|
label: Environment
|
||||||
description: Please specify the software and hardware you used to produce the bug.
|
description: Please specify the software and hardware you used to produce the bug.
|
||||||
placeholder: |
|
placeholder: |
|
||||||
- YOLO: YOLOv8 🚀 v6.0-67-g60e42e1 torch 1.9.0+cu111 CUDA:0 (A100-SXM4-40GB, 40536MiB)
|
- YOLO: Ultralytics YOLOv8.0.21 🚀 Python-3.8.10 torch-1.13.1+cu117 CUDA:0 (A100-SXM-80GB, 81251MiB)
|
||||||
- OS: Ubuntu 20.04
|
- OS: Ubuntu 20.04
|
||||||
- Python: 3.9.0
|
- Python: 3.8.10
|
||||||
validations:
|
validations:
|
||||||
required: false
|
required: false
|
||||||
|
|
||||||
|
@ -35,28 +35,29 @@ def test_train_cls():
|
|||||||
|
|
||||||
# Val checks -----------------------------------------------------------------------------------------------------------
|
# Val checks -----------------------------------------------------------------------------------------------------------
|
||||||
def test_val_detect():
|
def test_val_detect():
|
||||||
run(f'yolo val detect model={MODEL}.pt data=coco8.yaml imgsz=32 epochs=1')
|
run(f'yolo val detect model={MODEL}.pt data=coco8.yaml imgsz=32')
|
||||||
|
|
||||||
|
|
||||||
def test_val_segment():
|
def test_val_segment():
|
||||||
run(f'yolo val segment model={MODEL}-seg.pt data=coco8-seg.yaml imgsz=32 epochs=1')
|
run(f'yolo val segment model={MODEL}-seg.pt data=coco8-seg.yaml imgsz=32')
|
||||||
|
|
||||||
|
|
||||||
def test_val_classify():
|
def test_val_classify():
|
||||||
pass
|
run(f'yolo val classify model={MODEL}-cls.pt data=mnist160 imgsz=32')
|
||||||
|
|
||||||
|
|
||||||
# Predict checks -------------------------------------------------------------------------------------------------------
|
# Predict checks -------------------------------------------------------------------------------------------------------
|
||||||
def test_predict_detect():
|
def test_predict_detect():
|
||||||
run(f"yolo predict detect model={MODEL}.pt source={ROOT / 'assets'} imgsz=320 conf=0.25")
|
run(f"yolo predict detect model={MODEL}.pt source={ROOT / 'assets'} imgsz=32")
|
||||||
|
run(f"yolo predict detect model={MODEL}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32")
|
||||||
|
|
||||||
|
|
||||||
def test_predict_segment():
|
def test_predict_segment():
|
||||||
run(f"yolo predict segment model={MODEL}-seg.pt source={ROOT / 'assets'}")
|
run(f"yolo predict segment model={MODEL}-seg.pt source={ROOT / 'assets'} imgsz=32")
|
||||||
|
|
||||||
|
|
||||||
def test_predict_classify():
|
def test_predict_classify():
|
||||||
pass
|
run(f"yolo predict segment model={MODEL}-cls.pt source={ROOT / 'assets'} imgsz=32")
|
||||||
|
|
||||||
|
|
||||||
# Export checks --------------------------------------------------------------------------------------------------------
|
# Export checks --------------------------------------------------------------------------------------------------------
|
||||||
|
@ -111,9 +111,11 @@ def test_export_coreml():
|
|||||||
model.export(format='coreml')
|
model.export(format='coreml')
|
||||||
|
|
||||||
|
|
||||||
def test_export_paddle():
|
def test_export_paddle(enabled=False):
|
||||||
model = YOLO(MODEL)
|
# Paddle protobuf requirements conflicting with onnx protobuf requirements
|
||||||
model.export(format='paddle')
|
if enabled:
|
||||||
|
model = YOLO(MODEL)
|
||||||
|
model.export(format='paddle')
|
||||||
|
|
||||||
|
|
||||||
def test_all_model_yamls():
|
def test_all_model_yamls():
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
|
|
||||||
__version__ = "8.0.20"
|
__version__ = "8.0.21"
|
||||||
|
|
||||||
from ultralytics.yolo.engine.model import YOLO
|
from ultralytics.yolo.engine.model import YOLO
|
||||||
from ultralytics.yolo.utils import ops
|
from ultralytics.yolo.utils import ops
|
||||||
|
@ -9,8 +9,8 @@ from types import SimpleNamespace
|
|||||||
from typing import Dict, List, Union
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
from ultralytics import __version__
|
from ultralytics import __version__
|
||||||
from ultralytics.yolo.utils import (DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, PREFIX, ROOT, USER_CONFIG_DIR,
|
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, PREFIX, ROOT,
|
||||||
IterableSimpleNamespace, colorstr, yaml_load, yaml_print)
|
USER_CONFIG_DIR, IterableSimpleNamespace, colorstr, emojis, yaml_load, yaml_print)
|
||||||
from ultralytics.yolo.utils.checks import check_yolo
|
from ultralytics.yolo.utils.checks import check_yolo
|
||||||
|
|
||||||
CLI_HELP_MSG = \
|
CLI_HELP_MSG = \
|
||||||
@ -69,7 +69,7 @@ def cfg2dict(cfg):
|
|||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace], overrides: Dict = None):
|
def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG, overrides: Dict = None):
|
||||||
"""
|
"""
|
||||||
Load and merge configuration data from a file or dictionary.
|
Load and merge configuration data from a file or dictionary.
|
||||||
|
|
||||||
@ -214,17 +214,19 @@ def entrypoint(debug=False):
|
|||||||
# Mode
|
# Mode
|
||||||
mode = overrides.pop('mode', None)
|
mode = overrides.pop('mode', None)
|
||||||
model = overrides.pop('model', None)
|
model = overrides.pop('model', None)
|
||||||
if mode == 'checks':
|
if mode is None:
|
||||||
|
mode = DEFAULT_CFG.mode or 'predict'
|
||||||
|
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")
|
||||||
|
elif mode not in modes:
|
||||||
|
if mode != 'checks':
|
||||||
|
raise ValueError(emojis(f"ERROR ❌ Invalid 'mode={mode}'. Valid modes are {modes}."))
|
||||||
LOGGER.warning("WARNING ⚠️ 'yolo mode=checks' is deprecated. Use 'yolo checks' instead.")
|
LOGGER.warning("WARNING ⚠️ 'yolo mode=checks' is deprecated. Use 'yolo checks' instead.")
|
||||||
check_yolo()
|
check_yolo()
|
||||||
return
|
return
|
||||||
elif mode is None:
|
|
||||||
mode = DEFAULT_CFG_DICT['mode'] or 'predict'
|
|
||||||
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")
|
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
if model is None:
|
if model is None:
|
||||||
model = DEFAULT_CFG_DICT['model'] or 'yolov8n.pt'
|
model = DEFAULT_CFG.model or 'yolov8n.pt'
|
||||||
LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
|
LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
|
||||||
from ultralytics.yolo.engine.model import YOLO
|
from ultralytics.yolo.engine.model import YOLO
|
||||||
model = YOLO(model)
|
model = YOLO(model)
|
||||||
@ -232,21 +234,21 @@ def entrypoint(debug=False):
|
|||||||
|
|
||||||
# Task
|
# Task
|
||||||
if mode == 'predict' and 'source' not in overrides:
|
if mode == 'predict' and 'source' not in overrides:
|
||||||
overrides['source'] = DEFAULT_CFG_DICT['source'] or ROOT / "assets" if (ROOT / "assets").exists() \
|
overrides['source'] = DEFAULT_CFG.source or ROOT / "assets" if (ROOT / "assets").exists() \
|
||||||
else "https://ultralytics.com/images/bus.jpg"
|
else "https://ultralytics.com/images/bus.jpg"
|
||||||
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
|
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
|
||||||
elif mode in ('train', 'val'):
|
elif mode in ('train', 'val'):
|
||||||
if 'data' not in overrides:
|
if 'data' not in overrides:
|
||||||
overrides['data'] = DEFAULT_CFG_DICT['data'] or 'mnist160' if task == 'classify' \
|
overrides['data'] = DEFAULT_CFG.data or 'mnist160' if task == 'classify' \
|
||||||
else 'coco128-seg.yaml' if task == 'segment' else 'coco128.yaml'
|
else 'coco128-seg.yaml' if task == 'segment' else 'coco128.yaml'
|
||||||
LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
|
LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
|
||||||
elif mode == 'export':
|
elif mode == 'export':
|
||||||
if 'format' not in overrides:
|
if 'format' not in overrides:
|
||||||
overrides['format'] = DEFAULT_CFG_DICT['format'] or 'torchscript'
|
overrides['format'] = DEFAULT_CFG.format or 'torchscript'
|
||||||
LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.")
|
LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.")
|
||||||
|
|
||||||
# Run command in python
|
# Run command in python
|
||||||
getattr(model, mode)(verbose=True, **overrides)
|
getattr(model, mode)(**overrides)
|
||||||
|
|
||||||
|
|
||||||
# Special modes --------------------------------------------------------------------------------------------------------
|
# Special modes --------------------------------------------------------------------------------------------------------
|
||||||
|
@ -44,7 +44,8 @@ class LoadStreams:
|
|||||||
assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.'
|
assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.'
|
||||||
assert not is_kaggle(), '--source 0 webcam unsupported on Kaggle. Rerun command in a local environment.'
|
assert not is_kaggle(), '--source 0 webcam unsupported on Kaggle. Rerun command in a local environment.'
|
||||||
cap = cv2.VideoCapture(s)
|
cap = cv2.VideoCapture(s)
|
||||||
assert cap.isOpened(), f'{st}Failed to open {s}'
|
if not cap.isOpened():
|
||||||
|
raise ConnectionError(f'{st}Failed to open {s}')
|
||||||
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||||
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||||
fps = cap.get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
|
fps = cap.get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
|
||||||
@ -188,8 +189,9 @@ class LoadImages:
|
|||||||
self._new_video(videos[0]) # new video
|
self._new_video(videos[0]) # new video
|
||||||
else:
|
else:
|
||||||
self.cap = None
|
self.cap = None
|
||||||
assert self.nf > 0, f'No images or videos found in {p}. ' \
|
if self.nf == 0:
|
||||||
f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'
|
raise FileNotFoundError(f'No images or videos found in {p}. '
|
||||||
|
f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}')
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
self.count = 0
|
self.count = 0
|
||||||
@ -223,7 +225,8 @@ class LoadImages:
|
|||||||
# Read image
|
# Read image
|
||||||
self.count += 1
|
self.count += 1
|
||||||
im0 = cv2.imread(path) # BGR
|
im0 = cv2.imread(path) # BGR
|
||||||
assert im0 is not None, f'Image Not Found {path}'
|
if im0 is None:
|
||||||
|
raise FileNotFoundError(f'Image Not Found {path}')
|
||||||
s = f'image {self.count}/{self.nf} {path}: '
|
s = f'image {self.count}/{self.nf} {path}: '
|
||||||
|
|
||||||
if self.transforms:
|
if self.transforms:
|
||||||
|
@ -23,14 +23,13 @@ import numpy as np
|
|||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
import yaml
|
|
||||||
from PIL import ExifTags, Image, ImageOps
|
from PIL import ExifTags, Image, ImageOps
|
||||||
from torch.utils.data import DataLoader, Dataset, dataloader, distributed
|
from torch.utils.data import DataLoader, Dataset, dataloader, distributed
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from ultralytics.yolo.data.utils import check_det_dataset, unzip_file
|
from ultralytics.yolo.data.utils import check_det_dataset, unzip_file
|
||||||
from ultralytics.yolo.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, is_colab, is_dir_writeable,
|
from ultralytics.yolo.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, is_colab, is_dir_writeable,
|
||||||
is_kaggle)
|
is_kaggle, yaml_load)
|
||||||
from ultralytics.yolo.utils.checks import check_requirements, check_yaml
|
from ultralytics.yolo.utils.checks import check_requirements, check_yaml
|
||||||
from ultralytics.yolo.utils.ops import clean_str, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn
|
from ultralytics.yolo.utils.ops import clean_str, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn
|
||||||
from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first
|
from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first
|
||||||
@ -1056,10 +1055,9 @@ class HUBDatasetStats():
|
|||||||
# Initialize class
|
# Initialize class
|
||||||
zipped, data_dir, yaml_path = self._unzip(Path(path))
|
zipped, data_dir, yaml_path = self._unzip(Path(path))
|
||||||
try:
|
try:
|
||||||
with open(check_yaml(yaml_path), errors='ignore') as f:
|
data = yaml_load(check_yaml(yaml_path)) # data dict
|
||||||
data = yaml.safe_load(f) # data dict
|
if zipped:
|
||||||
if zipped:
|
data['path'] = data_dir
|
||||||
data['path'] = data_dir
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception("error/HUB/dataset_stats/yaml_load") from e
|
raise Exception("error/HUB/dataset_stats/yaml_load") from e
|
||||||
|
|
||||||
|
@ -129,7 +129,7 @@ class Exporter:
|
|||||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||||
"""
|
"""
|
||||||
self.args = get_cfg(cfg, overrides)
|
self.args = get_cfg(cfg, overrides)
|
||||||
self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()}) # add callbacks
|
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
|
||||||
callbacks.add_integration_callbacks(self)
|
callbacks.add_integration_callbacks(self)
|
||||||
|
|
||||||
@smart_inference_mode()
|
@smart_inference_mode()
|
||||||
|
@ -61,8 +61,8 @@ class YOLO:
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"'{suffix}' model loading not implemented")
|
raise NotImplementedError(f"'{suffix}' model loading not implemented")
|
||||||
|
|
||||||
def __call__(self, source=None, stream=False, verbose=False, **kwargs):
|
def __call__(self, source=None, stream=False, **kwargs):
|
||||||
return self.predict(source, stream, verbose, **kwargs)
|
return self.predict(source, stream, **kwargs)
|
||||||
|
|
||||||
def _new(self, cfg: str, verbose=True):
|
def _new(self, cfg: str, verbose=True):
|
||||||
"""
|
"""
|
||||||
@ -118,7 +118,7 @@ class YOLO:
|
|||||||
self.model.fuse()
|
self.model.fuse()
|
||||||
|
|
||||||
@smart_inference_mode()
|
@smart_inference_mode()
|
||||||
def predict(self, source=None, stream=False, verbose=False, **kwargs):
|
def predict(self, source=None, stream=False, **kwargs):
|
||||||
"""
|
"""
|
||||||
Perform prediction using the YOLO model.
|
Perform prediction using the YOLO model.
|
||||||
|
|
||||||
@ -126,7 +126,6 @@ class YOLO:
|
|||||||
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
|
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
|
||||||
Accepts all source types accepted by the YOLO model.
|
Accepts all source types accepted by the YOLO model.
|
||||||
stream (bool): Whether to stream the predictions or not. Defaults to False.
|
stream (bool): Whether to stream the predictions or not. Defaults to False.
|
||||||
verbose (bool): Whether to print verbose information or not. Defaults to False.
|
|
||||||
**kwargs : Additional keyword arguments passed to the predictor.
|
**kwargs : Additional keyword arguments passed to the predictor.
|
||||||
Check the 'configuration' section in the documentation for all available options.
|
Check the 'configuration' section in the documentation for all available options.
|
||||||
|
|
||||||
@ -143,7 +142,7 @@ class YOLO:
|
|||||||
self.predictor.setup_model(model=self.model)
|
self.predictor.setup_model(model=self.model)
|
||||||
else: # only update args if predictor is already setup
|
else: # only update args if predictor is already setup
|
||||||
self.predictor.args = get_cfg(self.predictor.args, overrides)
|
self.predictor.args = get_cfg(self.predictor.args, overrides)
|
||||||
return self.predictor(source=source, stream=stream, verbose=verbose)
|
return self.predictor(source=source, stream=stream)
|
||||||
|
|
||||||
@smart_inference_mode()
|
@smart_inference_mode()
|
||||||
def val(self, data=None, **kwargs):
|
def val(self, data=None, **kwargs):
|
||||||
@ -234,7 +233,8 @@ class YOLO:
|
|||||||
"""
|
"""
|
||||||
return self.model.names
|
return self.model.names
|
||||||
|
|
||||||
def add_callback(self, event: str, func):
|
@staticmethod
|
||||||
|
def add_callback(event: str, func):
|
||||||
"""
|
"""
|
||||||
Add callback
|
Add callback
|
||||||
"""
|
"""
|
||||||
@ -242,16 +242,8 @@ class YOLO:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reset_ckpt_args(args):
|
def _reset_ckpt_args(args):
|
||||||
args.pop("project", None)
|
for arg in 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', 'save_json', \
|
||||||
args.pop("name", None)
|
'half', 'v5loader':
|
||||||
args.pop("exist_ok", None)
|
args.pop(arg, None)
|
||||||
args.pop("resume", None)
|
|
||||||
args.pop("batch", None)
|
|
||||||
args.pop("epochs", None)
|
|
||||||
args.pop("cache", None)
|
|
||||||
args.pop("save_json", None)
|
|
||||||
args.pop("half", None)
|
|
||||||
args.pop("v5loader", None)
|
|
||||||
|
|
||||||
# set device to '' to prevent from auto DDP usage
|
args["device"] = '' # set device to '' to prevent auto-DDP usage
|
||||||
args["device"] = ''
|
|
||||||
|
@ -88,7 +88,7 @@ class BasePredictor:
|
|||||||
self.vid_path, self.vid_writer = None, None
|
self.vid_path, self.vid_writer = None, None
|
||||||
self.annotator = None
|
self.annotator = None
|
||||||
self.data_path = None
|
self.data_path = None
|
||||||
self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()}) # add callbacks
|
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
|
||||||
callbacks.add_integration_callbacks(self)
|
callbacks.add_integration_callbacks(self)
|
||||||
|
|
||||||
def preprocess(self, img):
|
def preprocess(self, img):
|
||||||
@ -151,19 +151,19 @@ class BasePredictor:
|
|||||||
self.bs = bs
|
self.bs = bs
|
||||||
|
|
||||||
@smart_inference_mode()
|
@smart_inference_mode()
|
||||||
def __call__(self, source=None, model=None, verbose=False, stream=False):
|
def __call__(self, source=None, model=None, stream=False):
|
||||||
if stream:
|
if stream:
|
||||||
return self.stream_inference(source, model, verbose)
|
return self.stream_inference(source, model)
|
||||||
else:
|
else:
|
||||||
return list(self.stream_inference(source, model, verbose)) # merge list of Result into one
|
return list(self.stream_inference(source, model)) # merge list of Result into one
|
||||||
|
|
||||||
def predict_cli(self):
|
def predict_cli(self):
|
||||||
# Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode
|
# Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode
|
||||||
gen = self.stream_inference(verbose=True)
|
gen = self.stream_inference()
|
||||||
for _ in gen: # running CLI inference without accumulating any outputs (do not modify)
|
for _ in gen: # running CLI inference without accumulating any outputs (do not modify)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def stream_inference(self, source=None, model=None, verbose=False):
|
def stream_inference(self, source=None, model=None):
|
||||||
self.run_callbacks("on_predict_start")
|
self.run_callbacks("on_predict_start")
|
||||||
|
|
||||||
# setup model
|
# setup model
|
||||||
@ -201,7 +201,7 @@ class BasePredictor:
|
|||||||
p, im0 = (path[i], im0s[i]) if self.webcam or self.from_img else (path, im0s)
|
p, im0 = (path[i], im0s[i]) if self.webcam or self.from_img else (path, im0s)
|
||||||
p = Path(p)
|
p = Path(p)
|
||||||
|
|
||||||
if verbose or self.args.save or self.args.save_txt or self.args.show:
|
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
|
||||||
s += self.write_results(i, self.results, (p, im, im0))
|
s += self.write_results(i, self.results, (p, im, im0))
|
||||||
|
|
||||||
if self.args.show:
|
if self.args.show:
|
||||||
@ -214,11 +214,11 @@ class BasePredictor:
|
|||||||
yield from self.results
|
yield from self.results
|
||||||
|
|
||||||
# Print time (inference-only)
|
# Print time (inference-only)
|
||||||
if verbose:
|
if self.args.verbose:
|
||||||
LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms")
|
LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms")
|
||||||
|
|
||||||
# Print results
|
# Print results
|
||||||
if verbose and self.seen:
|
if self.args.verbose and self.seen:
|
||||||
t = tuple(x.t / self.seen * 1E3 for x in self.dt) # speeds per image
|
t = tuple(x.t / self.seen * 1E3 for x in self.dt) # speeds per image
|
||||||
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms postprocess per image at shape '
|
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms postprocess per image at shape '
|
||||||
f'{(1, 3, *self.imgsz)}' % t)
|
f'{(1, 3, *self.imgsz)}' % t)
|
||||||
@ -243,7 +243,7 @@ class BasePredictor:
|
|||||||
if isinstance(source, (str, int, Path)): # int for local usb carame
|
if isinstance(source, (str, int, Path)): # int for local usb carame
|
||||||
source = str(source)
|
source = str(source)
|
||||||
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
|
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
|
||||||
is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
|
is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
|
||||||
webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
|
webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
|
||||||
screenshot = source.lower().startswith('screen')
|
screenshot = source.lower().startswith('screen')
|
||||||
if is_url and is_file:
|
if is_url and is_file:
|
||||||
|
@ -85,7 +85,6 @@ class BaseTrainer:
|
|||||||
self.console = LOGGER
|
self.console = LOGGER
|
||||||
self.validator = None
|
self.validator = None
|
||||||
self.model = None
|
self.model = None
|
||||||
self.callbacks = defaultdict(list)
|
|
||||||
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
|
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
|
||||||
|
|
||||||
# Dirs
|
# Dirs
|
||||||
@ -141,7 +140,7 @@ class BaseTrainer:
|
|||||||
self.plot_idx = [0, 1, 2]
|
self.plot_idx = [0, 1, 2]
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()}) # add callbacks
|
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
|
||||||
if RANK in {0, -1}:
|
if RANK in {0, -1}:
|
||||||
callbacks.add_integration_callbacks(self)
|
callbacks.add_integration_callbacks(self)
|
||||||
|
|
||||||
|
@ -70,7 +70,7 @@ class BaseValidator:
|
|||||||
if self.args.conf is None:
|
if self.args.conf is None:
|
||||||
self.args.conf = 0.001 # default conf=0.001
|
self.args.conf = 0.001 # default conf=0.001
|
||||||
|
|
||||||
self.callbacks = defaultdict(list, {k: v for k, v in callbacks.default_callbacks.items()}) # add callbacks
|
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
|
||||||
|
|
||||||
@smart_inference_mode()
|
@smart_inference_mode()
|
||||||
def __call__(self, trainer=None, model=None):
|
def __call__(self, trainer=None, model=None):
|
||||||
|
@ -5,6 +5,7 @@ import inspect
|
|||||||
import logging.config
|
import logging.config
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
@ -113,12 +114,66 @@ class IterableSimpleNamespace(SimpleNamespace):
|
|||||||
return getattr(self, key, default)
|
return getattr(self, key, default)
|
||||||
|
|
||||||
|
|
||||||
|
def yaml_save(file='data.yaml', data=None):
|
||||||
|
"""
|
||||||
|
Save YAML data to a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file (str, optional): File name. Default is 'data.yaml'.
|
||||||
|
data (dict, optional): Data to save in YAML format. Default is None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None: Data is saved to the specified file.
|
||||||
|
"""
|
||||||
|
file = Path(file)
|
||||||
|
if not file.parent.exists():
|
||||||
|
# Create parent directories if they don't exist
|
||||||
|
file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
with open(file, 'w') as f:
|
||||||
|
# Dump data to file in YAML format, converting Path objects to strings
|
||||||
|
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
|
||||||
|
|
||||||
|
|
||||||
|
def yaml_load(file='data.yaml', append_filename=False):
|
||||||
|
"""
|
||||||
|
Load YAML data from a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file (str, optional): File name. Default is 'data.yaml'.
|
||||||
|
append_filename (bool): Add the YAML filename to the YAML dictionary. Default is False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: YAML data and file name.
|
||||||
|
"""
|
||||||
|
with open(file, errors='ignore', encoding='utf-8') as f:
|
||||||
|
# Add YAML filename to dict and return
|
||||||
|
s = f.read() # string
|
||||||
|
if not s.isprintable(): # remove special characters
|
||||||
|
s = re.sub(r'[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+', '', s)
|
||||||
|
return {**yaml.safe_load(s), 'yaml_file': str(file)} if append_filename else yaml.safe_load(s)
|
||||||
|
|
||||||
|
|
||||||
|
def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
|
||||||
|
"""
|
||||||
|
Pretty prints a yaml file or a yaml-formatted dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
yaml_file: The file path of the yaml file or a yaml-formatted dictionary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
yaml_dict = yaml_load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file
|
||||||
|
dump = yaml.dump(yaml_dict, default_flow_style=False)
|
||||||
|
LOGGER.info(f"Printing '{colorstr('bold', 'black', yaml_file)}'\n\n{dump}")
|
||||||
|
|
||||||
|
|
||||||
# Default configuration
|
# Default configuration
|
||||||
with open(DEFAULT_CFG_PATH, errors='ignore') as f:
|
DEFAULT_CFG_DICT = yaml_load(DEFAULT_CFG_PATH)
|
||||||
DEFAULT_CFG_DICT = yaml.safe_load(f)
|
for k, v in DEFAULT_CFG_DICT.items():
|
||||||
for k, v in DEFAULT_CFG_DICT.items():
|
if isinstance(v, str) and v.lower() == 'none':
|
||||||
if isinstance(v, str) and v.lower() == 'none':
|
DEFAULT_CFG_DICT[k] = None
|
||||||
DEFAULT_CFG_DICT[k] = None
|
|
||||||
DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
|
DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
|
||||||
DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)
|
DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)
|
||||||
|
|
||||||
@ -393,58 +448,6 @@ def threaded(func):
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def yaml_save(file='data.yaml', data=None):
|
|
||||||
"""
|
|
||||||
Save YAML data to a file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file (str, optional): File name. Default is 'data.yaml'.
|
|
||||||
data (dict, optional): Data to save in YAML format. Default is None.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None: Data is saved to the specified file.
|
|
||||||
"""
|
|
||||||
file = Path(file)
|
|
||||||
if not file.parent.exists():
|
|
||||||
# Create parent directories if they don't exist
|
|
||||||
file.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
with open(file, 'w') as f:
|
|
||||||
# Dump data to file in YAML format, converting Path objects to strings
|
|
||||||
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
|
|
||||||
|
|
||||||
|
|
||||||
def yaml_load(file='data.yaml', append_filename=False):
|
|
||||||
"""
|
|
||||||
Load YAML data from a file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file (str, optional): File name. Default is 'data.yaml'.
|
|
||||||
append_filename (bool): Add the YAML filename to the YAML dictionary. Default is False.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: YAML data and file name.
|
|
||||||
"""
|
|
||||||
with open(file, errors='ignore') as f:
|
|
||||||
# Add YAML filename to dict and return
|
|
||||||
return {**yaml.safe_load(f), 'yaml_file': str(file)} if append_filename else yaml.safe_load(f)
|
|
||||||
|
|
||||||
|
|
||||||
def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
|
|
||||||
"""
|
|
||||||
Pretty prints a yaml file or a yaml-formatted dictionary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
yaml_file: The file path of the yaml file or a yaml-formatted dictionary.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
yaml_dict = yaml_load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file
|
|
||||||
dump = yaml.dump(yaml_dict, default_flow_style=False)
|
|
||||||
LOGGER.info(f"Printing '{colorstr('bold', 'black', yaml_file)}'\n\n{dump}")
|
|
||||||
|
|
||||||
|
|
||||||
def set_sentry():
|
def set_sentry():
|
||||||
"""
|
"""
|
||||||
Initialize the Sentry SDK for error tracking and reporting if pytest is not currently running.
|
Initialize the Sentry SDK for error tracking and reporting if pytest is not currently running.
|
||||||
|
@ -207,9 +207,9 @@ def check_file(file, suffix=''):
|
|||||||
# Search/download file (if necessary) and return path
|
# Search/download file (if necessary) and return path
|
||||||
check_suffix(file, suffix) # optional
|
check_suffix(file, suffix) # optional
|
||||||
file = str(file) # convert to str()
|
file = str(file) # convert to str()
|
||||||
if Path(file).is_file() or not file: # exists
|
if not file or ('://' not in file and Path(file).is_file()): # exists ('://' check required in Windows Python<3.10)
|
||||||
return file
|
return file
|
||||||
elif file.startswith(('http:/', 'https:/')): # download
|
elif file.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://')): # download
|
||||||
url = file # warning: Pathlib turns :// -> :/
|
url = file # warning: Pathlib turns :// -> :/
|
||||||
file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
|
file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
|
||||||
if Path(file).is_file():
|
if Path(file).is_file():
|
||||||
@ -276,7 +276,7 @@ def git_describe(path=ROOT): # path must be a directory
|
|||||||
try:
|
try:
|
||||||
assert (Path(path) / '.git').is_dir()
|
assert (Path(path) / '.git').is_dir()
|
||||||
return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
|
return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
|
||||||
except Exception:
|
except AssertionError:
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
|
|
||||||
|
@ -104,7 +104,7 @@ def download(url, dir=Path.cwd(), unzip=True, delete=True, curl=False, threads=1
|
|||||||
def download_one(url, dir):
|
def download_one(url, dir):
|
||||||
# Download 1 file
|
# Download 1 file
|
||||||
success = True
|
success = True
|
||||||
if Path(url).is_file():
|
if '://' not in str(url) and Path(url).is_file(): # exists ('://' check required in Windows Python<3.10)
|
||||||
f = Path(url) # filename
|
f = Path(url) # filename
|
||||||
else: # does not exist
|
else: # does not exist
|
||||||
f = dir / Path(url).name
|
f = dir / Path(url).name
|
||||||
|
@ -17,11 +17,8 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
import ultralytics
|
|
||||||
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER
|
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER
|
||||||
from ultralytics.yolo.utils.checks import git_describe
|
from ultralytics.yolo.utils.checks import check_version
|
||||||
|
|
||||||
from .checks import check_version
|
|
||||||
|
|
||||||
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
||||||
RANK = int(os.getenv('RANK', -1))
|
RANK = int(os.getenv('RANK', -1))
|
||||||
@ -60,8 +57,8 @@ def DDP_model(model):
|
|||||||
|
|
||||||
def select_device(device='', batch=0, newline=False):
|
def select_device(device='', batch=0, newline=False):
|
||||||
# device = None or 'cpu' or 0 or '0' or '0,1,2,3'
|
# device = None or 'cpu' or 0 or '0' or '0,1,2,3'
|
||||||
ver = git_describe() or ultralytics.__version__ # git commit or pip package version
|
from ultralytics import __version__
|
||||||
s = f'Ultralytics YOLOv{ver} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
|
s = f'Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
|
||||||
device = str(device).lower()
|
device = str(device).lower()
|
||||||
for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ':
|
for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ':
|
||||||
device = device.replace(remove, '') # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
|
device = device.replace(remove, '') # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
|
||||||
@ -247,6 +244,7 @@ class ModelEMA:
|
|||||||
""" Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
|
""" Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
|
||||||
Keeps a moving average of everything in the model state_dict (parameters and buffers)
|
Keeps a moving average of everything in the model state_dict (parameters and buffers)
|
||||||
For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
|
For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
|
||||||
|
To disable EMA set the `enabled` attribute to `False`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
|
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
|
||||||
@ -256,22 +254,25 @@ class ModelEMA:
|
|||||||
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
|
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
|
||||||
for p in self.ema.parameters():
|
for p in self.ema.parameters():
|
||||||
p.requires_grad_(False)
|
p.requires_grad_(False)
|
||||||
|
self.enabled = True
|
||||||
|
|
||||||
def update(self, model):
|
def update(self, model):
|
||||||
# Update EMA parameters
|
# Update EMA parameters
|
||||||
self.updates += 1
|
if self.enabled:
|
||||||
d = self.decay(self.updates)
|
self.updates += 1
|
||||||
|
d = self.decay(self.updates)
|
||||||
|
|
||||||
msd = de_parallel(model).state_dict() # model state_dict
|
msd = de_parallel(model).state_dict() # model state_dict
|
||||||
for k, v in self.ema.state_dict().items():
|
for k, v in self.ema.state_dict().items():
|
||||||
if v.dtype.is_floating_point: # true for FP16 and FP32
|
if v.dtype.is_floating_point: # true for FP16 and FP32
|
||||||
v *= d
|
v *= d
|
||||||
v += (1 - d) * msd[k].detach()
|
v += (1 - d) * msd[k].detach()
|
||||||
# assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
|
# assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'
|
||||||
|
|
||||||
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
|
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
|
||||||
# Update EMA attributes
|
# Update EMA attributes
|
||||||
copy_attr(self.ema, model, include, exclude)
|
if self.enabled:
|
||||||
|
copy_attr(self.ema, model, include, exclude)
|
||||||
|
|
||||||
|
|
||||||
def strip_optimizer(f='best.pt', s=''):
|
def strip_optimizer(f='best.pt', s=''):
|
||||||
@ -285,8 +286,8 @@ def strip_optimizer(f='best.pt', s=''):
|
|||||||
strip_optimizer(f)
|
strip_optimizer(f)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
f (str): file path to model state to strip the optimizer from. Default is 'best.pt'.
|
f (str): file path to model to strip the optimizer from. Default is 'best.pt'.
|
||||||
s (str): file path to save the model with stripped optimizer to. Default is ''. If not provided, the original file will be overwritten.
|
s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
@ -364,12 +365,12 @@ class EarlyStopping:
|
|||||||
Early stopping class that stops training when a specified number of epochs have passed without improvement.
|
Early stopping class that stops training when a specified number of epochs have passed without improvement.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, patience=30):
|
def __init__(self, patience=50):
|
||||||
"""
|
"""
|
||||||
Initialize early stopping object
|
Initialize early stopping object
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
patience (int, optional): Number of epochs to wait after fitness stops improving before stopping. Default is 30.
|
patience (int, optional): Number of epochs to wait after fitness stops improving before stopping.
|
||||||
"""
|
"""
|
||||||
self.best_fitness = 0.0 # i.e. mAP
|
self.best_fitness = 0.0 # i.e. mAP
|
||||||
self.best_epoch = 0
|
self.best_epoch = 0
|
||||||
|
Loading…
x
Reference in New Issue
Block a user