mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-07-07 13:44:23 +08:00
ultralytics 8.0.12
- Hydra removal (#506)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Pronoy Mandal <lukex9442@gmail.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
parent
6eec39162a
commit
c5fccc3fc4
3
.github/workflows/ci.yaml
vendored
3
.github/workflows/ci.yaml
vendored
@ -85,6 +85,7 @@ jobs:
|
|||||||
shell: bash # for Windows compatibility
|
shell: bash # for Windows compatibility
|
||||||
run: |
|
run: |
|
||||||
yolo task=detect mode=train data=coco8.yaml model=yolov8n.yaml epochs=1 imgsz=32
|
yolo task=detect mode=train data=coco8.yaml model=yolov8n.yaml epochs=1 imgsz=32
|
||||||
|
yolo task=detect mode=train data=coco8.yaml model=yolov8n.pt epochs=1 imgsz=32
|
||||||
yolo task=detect mode=val data=coco8.yaml model=runs/detect/train/weights/last.pt imgsz=32
|
yolo task=detect mode=val data=coco8.yaml model=runs/detect/train/weights/last.pt imgsz=32
|
||||||
yolo task=detect mode=predict model=runs/detect/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg
|
yolo task=detect mode=predict model=runs/detect/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg
|
||||||
yolo mode=export model=runs/detect/train/weights/last.pt imgsz=32 format=torchscript
|
yolo mode=export model=runs/detect/train/weights/last.pt imgsz=32 format=torchscript
|
||||||
@ -92,6 +93,7 @@ jobs:
|
|||||||
shell: bash # for Windows compatibility
|
shell: bash # for Windows compatibility
|
||||||
run: |
|
run: |
|
||||||
yolo task=segment mode=train data=coco8-seg.yaml model=yolov8n-seg.yaml epochs=1 imgsz=32
|
yolo task=segment mode=train data=coco8-seg.yaml model=yolov8n-seg.yaml epochs=1 imgsz=32
|
||||||
|
yolo task=segment mode=train data=coco8-seg.yaml model=yolov8n-seg.pt epochs=1 imgsz=32
|
||||||
yolo task=segment mode=val data=coco8-seg.yaml model=runs/segment/train/weights/last.pt imgsz=32
|
yolo task=segment mode=val data=coco8-seg.yaml model=runs/segment/train/weights/last.pt imgsz=32
|
||||||
yolo task=segment mode=predict model=runs/segment/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg
|
yolo task=segment mode=predict model=runs/segment/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg
|
||||||
yolo mode=export model=runs/segment/train/weights/last.pt imgsz=32 format=torchscript
|
yolo mode=export model=runs/segment/train/weights/last.pt imgsz=32 format=torchscript
|
||||||
@ -99,6 +101,7 @@ jobs:
|
|||||||
shell: bash # for Windows compatibility
|
shell: bash # for Windows compatibility
|
||||||
run: |
|
run: |
|
||||||
yolo task=classify mode=train data=mnist160 model=yolov8n-cls.yaml epochs=1 imgsz=32
|
yolo task=classify mode=train data=mnist160 model=yolov8n-cls.yaml epochs=1 imgsz=32
|
||||||
|
yolo task=classify mode=train data=mnist160 model=yolov8n-cls.pt epochs=1 imgsz=32
|
||||||
yolo task=classify mode=val data=mnist160 model=runs/classify/train/weights/last.pt imgsz=32
|
yolo task=classify mode=val data=mnist160 model=runs/classify/train/weights/last.pt imgsz=32
|
||||||
yolo task=classify mode=predict model=runs/classify/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg
|
yolo task=classify mode=predict model=runs/classify/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg
|
||||||
yolo mode=export model=runs/classify/train/weights/last.pt imgsz=32 format=torchscript
|
yolo mode=export model=runs/classify/train/weights/last.pt imgsz=32 format=torchscript
|
||||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -136,6 +136,7 @@ wandb/
|
|||||||
.DS_Store
|
.DS_Store
|
||||||
|
|
||||||
# Neural Network weights -----------------------------------------------------------------------------------------------
|
# Neural Network weights -----------------------------------------------------------------------------------------------
|
||||||
|
weights/
|
||||||
*.weights
|
*.weights
|
||||||
*.pt
|
*.pt
|
||||||
*.pb
|
*.pb
|
||||||
|
@ -10,7 +10,7 @@ ADD https://ultralytics.com/assets/Arial.ttf https://ultralytics.com/assets/Aria
|
|||||||
|
|
||||||
# Remove torch nightly and install torch stable
|
# Remove torch nightly and install torch stable
|
||||||
RUN rm -rf /opt/pytorch # remove 1.2GB dir
|
RUN rm -rf /opt/pytorch # remove 1.2GB dir
|
||||||
RUN pip uninstall -y torchtext torch torchvision
|
RUN pip uninstall -y torchtext pillow torch torchvision
|
||||||
RUN pip install --no-cache torch torchvision
|
RUN pip install --no-cache torch torchvision
|
||||||
|
|
||||||
# Install linux packages
|
# Install linux packages
|
||||||
|
@ -6,9 +6,9 @@ Inference or prediction of a task returns a list of `Results` objects. Alternati
|
|||||||
inputs = [img, img] # list of np arrays
|
inputs = [img, img] # list of np arrays
|
||||||
results = model(inputs) # List of Results objects
|
results = model(inputs) # List of Results objects
|
||||||
for result in results:
|
for result in results:
|
||||||
boxes = results.boxes # Boxes object for bbox outputs
|
boxes = result.boxes # Boxes object for bbox outputs
|
||||||
masks = results.masks # Masks object for segmenation masks outputs
|
masks = result.masks # Masks object for segmenation masks outputs
|
||||||
probs = results.probs # Class probabilities for classification outputs
|
probs = result.probs # Class probabilities for classification outputs
|
||||||
...
|
...
|
||||||
```
|
```
|
||||||
=== "Getting a Generator"
|
=== "Getting a Generator"
|
||||||
@ -16,9 +16,9 @@ Inference or prediction of a task returns a list of `Results` objects. Alternati
|
|||||||
inputs = [img, img] # list of np arrays
|
inputs = [img, img] # list of np arrays
|
||||||
results = model(inputs, stream=True) # Generator of Results objects
|
results = model(inputs, stream=True) # Generator of Results objects
|
||||||
for result in results:
|
for result in results:
|
||||||
boxes = results.boxes # Boxes object for bbox outputs
|
boxes = result.boxes # Boxes object for bbox outputs
|
||||||
masks = results.masks # Masks object for segmenation masks outputs
|
masks = result.masks # Masks object for segmenation masks outputs
|
||||||
probs = results.probs # Class probabilities for classification outputs
|
probs = result.probs # Class probabilities for classification outputs
|
||||||
...
|
...
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -69,4 +69,4 @@ results.masks.segments # bounding coordinates of masks, List[segment] * N
|
|||||||
results.probs # cls prob, (num_class, )
|
results.probs # cls prob, (num_class, )
|
||||||
```
|
```
|
||||||
|
|
||||||
Class reference documentation for `Results` module and its components can be found [here](reference/results.md)
|
Class reference documentation for `Results` module and its components can be found [here](reference/results.md)
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
# Usage: pip install -r requirements.txt
|
# Usage: pip install -r requirements.txt
|
||||||
|
|
||||||
# Base ----------------------------------------
|
# Base ----------------------------------------
|
||||||
hydra-core>=1.2.0
|
|
||||||
matplotlib>=3.2.2
|
matplotlib>=3.2.2
|
||||||
numpy>=1.18.5
|
numpy>=1.18.5
|
||||||
opencv-python>=4.1.1
|
opencv-python>=4.1.1
|
||||||
|
3
setup.py
3
setup.py
@ -51,4 +51,5 @@ setup(
|
|||||||
"Operating System :: MacOS", "Operating System :: Microsoft :: Windows"],
|
"Operating System :: MacOS", "Operating System :: Microsoft :: Windows"],
|
||||||
keywords="machine-learning, deep-learning, vision, ML, DL, AI, YOLO, YOLOv3, YOLOv5, YOLOv8, HUB, Ultralytics",
|
keywords="machine-learning, deep-learning, vision, ML, DL, AI, YOLO, YOLOv3, YOLOv5, YOLOv8, HUB, Ultralytics",
|
||||||
entry_points={
|
entry_points={
|
||||||
'console_scripts': ['yolo = ultralytics.yolo.cli:entrypoint', 'ultralytics = ultralytics.yolo.cli:entrypoint']})
|
'console_scripts':
|
||||||
|
['yolo = ultralytics.yolo.configs:entrypoint', 'ultralytics = ultralytics.yolo.configs:entrypoint']})
|
||||||
|
@ -3,13 +3,13 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from ultralytics.yolo.configs import get_config
|
from ultralytics.yolo.configs import get_config
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, SETTINGS
|
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, ROOT, SETTINGS
|
||||||
from ultralytics.yolo.v8 import classify, detect, segment
|
from ultralytics.yolo.v8 import classify, detect, segment
|
||||||
|
|
||||||
CFG_DET = 'yolov8n.yaml'
|
CFG_DET = 'yolov8n.yaml'
|
||||||
CFG_SEG = 'yolov8n-seg.yaml'
|
CFG_SEG = 'yolov8n-seg.yaml'
|
||||||
CFG_CLS = 'squeezenet1_0'
|
CFG_CLS = 'squeezenet1_0'
|
||||||
CFG = get_config(DEFAULT_CONFIG)
|
CFG = get_config(DEFAULT_CFG_PATH)
|
||||||
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n'
|
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n'
|
||||||
SOURCE = ROOT / "assets"
|
SOURCE = ROOT / "assets"
|
||||||
|
|
||||||
|
@ -49,6 +49,8 @@ def test_predict_img():
|
|||||||
assert len(output) == 1, "predict test failed"
|
assert len(output) == 1, "predict test failed"
|
||||||
output = model(source=[img, img], save=True, save_txt=True) # batch
|
output = model(source=[img, img], save=True, save_txt=True) # batch
|
||||||
assert len(output) == 2, "predict test failed"
|
assert len(output) == 2, "predict test failed"
|
||||||
|
output = model(source=[img, img], save=True, stream=True) # stream
|
||||||
|
assert len(list(output)) == 2, "predict test failed"
|
||||||
tens = torch.zeros(320, 640, 3)
|
tens = torch.zeros(320, 640, 3)
|
||||||
output = model(tens.numpy())
|
output = model(tens.numpy())
|
||||||
assert len(output) == 1, "predict test failed"
|
assert len(output) == 1, "predict test failed"
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
|
|
||||||
__version__ = "8.0.11"
|
__version__ = "8.0.12"
|
||||||
|
|
||||||
from ultralytics.yolo.engine.model import YOLO
|
from ultralytics.yolo.engine.model import YOLO
|
||||||
from ultralytics.yolo.utils import ops
|
from ultralytics.yolo.utils import ops
|
||||||
|
@ -7,7 +7,7 @@ import time
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, LOGGER, RANK, SETTINGS, TryExcept, colorstr, emojis
|
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, LOGGER, RANK, SETTINGS, TryExcept, colorstr, emojis
|
||||||
|
|
||||||
PREFIX = colorstr('Ultralytics: ')
|
PREFIX = colorstr('Ultralytics: ')
|
||||||
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
|
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
|
||||||
@ -143,7 +143,7 @@ def sync_analytics(cfg, all_keys=False, enabled=False):
|
|||||||
if SETTINGS['sync'] and RANK in {-1, 0} and enabled:
|
if SETTINGS['sync'] and RANK in {-1, 0} and enabled:
|
||||||
cfg = dict(cfg) # convert type from DictConfig to dict
|
cfg = dict(cfg) # convert type from DictConfig to dict
|
||||||
if not all_keys:
|
if not all_keys:
|
||||||
cfg = {k: v for k, v in cfg.items() if v != DEFAULT_CONFIG_DICT.get(k, None)} # retain non-default values
|
cfg = {k: v for k, v in cfg.items() if v != DEFAULT_CFG_DICT.get(k, None)} # retain non-default values
|
||||||
cfg['uuid'] = SETTINGS['uuid'] # add the device UUID to the configuration data
|
cfg['uuid'] = SETTINGS['uuid'] # add the device UUID to the configuration data
|
||||||
|
|
||||||
# Send a request to the HUB API to sync analytics
|
# Send a request to the HUB API to sync analytics
|
||||||
|
@ -10,7 +10,7 @@ import torch.nn as nn
|
|||||||
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
|
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
|
||||||
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
|
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
|
||||||
GhostBottleneck, GhostConv, Segment)
|
GhostBottleneck, GhostConv, Segment)
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER, colorstr, yaml_load
|
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, yaml_load
|
||||||
from ultralytics.yolo.utils.checks import check_yaml
|
from ultralytics.yolo.utils.checks import check_yaml
|
||||||
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
|
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
|
||||||
model_info, scale_img, time_sync)
|
model_info, scale_img, time_sync)
|
||||||
@ -113,7 +113,7 @@ class BaseModel(nn.Module):
|
|||||||
thresh (int, optional): The threshold number of BatchNorm layers. Default is 10.
|
thresh (int, optional): The threshold number of BatchNorm layers. Default is 10.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
|
(bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
|
||||||
"""
|
"""
|
||||||
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
|
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
|
||||||
return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
|
return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
|
||||||
@ -321,11 +321,11 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
|||||||
model = Ensemble()
|
model = Ensemble()
|
||||||
for w in weights if isinstance(weights, list) else [weights]:
|
for w in weights if isinstance(weights, list) else [weights]:
|
||||||
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
|
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
|
||||||
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
||||||
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||||
|
|
||||||
# Model compatibility updates
|
# Model compatibility updates
|
||||||
ckpt.args = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} # attach args to model
|
ckpt.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
|
||||||
ckpt.pt_path = weights # attach *.pt file path to model
|
ckpt.pt_path = weights # attach *.pt file path to model
|
||||||
if not hasattr(ckpt, 'stride'):
|
if not hasattr(ckpt, 'stride'):
|
||||||
ckpt.stride = torch.tensor([32.])
|
ckpt.stride = torch.tensor([32.])
|
||||||
@ -359,11 +359,11 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
|||||||
from ultralytics.yolo.utils.downloads import attempt_download
|
from ultralytics.yolo.utils.downloads import attempt_download
|
||||||
|
|
||||||
ckpt = torch.load(attempt_download(weight), map_location='cpu') # load
|
ckpt = torch.load(attempt_download(weight), map_location='cpu') # load
|
||||||
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
||||||
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||||
|
|
||||||
# Model compatibility updates
|
# Model compatibility updates
|
||||||
model.args = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} # attach args to model
|
model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
|
||||||
model.pt_path = weight # attach *.pt file path to model
|
model.pt_path = weight # attach *.pt file path to model
|
||||||
if not hasattr(model, 'stride'):
|
if not hasattr(model, 'stride'):
|
||||||
model.stride = torch.tensor([32.])
|
model.stride = torch.tensor([32.])
|
||||||
|
@ -1,156 +0,0 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import re
|
|
||||||
import shutil
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from ultralytics import __version__, yolo
|
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, PREFIX, checks, print_settings, yaml_load
|
|
||||||
|
|
||||||
DIR = Path(__file__).parent
|
|
||||||
|
|
||||||
CLI_HELP_MSG = \
|
|
||||||
"""
|
|
||||||
YOLOv8 CLI Usage examples:
|
|
||||||
|
|
||||||
1. Install the ultralytics package:
|
|
||||||
|
|
||||||
pip install ultralytics
|
|
||||||
|
|
||||||
2. Train, Val, Predict and Export using 'yolo' commands:
|
|
||||||
|
|
||||||
yolo TASK MODE ARGS
|
|
||||||
|
|
||||||
Where TASK (optional) is one of [detect, segment, classify]
|
|
||||||
MODE (required) is one of [train, val, predict, export]
|
|
||||||
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
|
|
||||||
For a full list of available ARGS see https://docs.ultralytics.com/config.
|
|
||||||
|
|
||||||
Train a detection model for 10 epochs with an initial learning_rate of 0.01
|
|
||||||
yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
|
|
||||||
|
|
||||||
Predict a YouTube video using a pretrained segmentation model at image size 320:
|
|
||||||
yolo segment predict model=yolov8n-seg.pt source=https://youtu.be/Zgi9g1ksQHc imgsz=320
|
|
||||||
|
|
||||||
Validate a pretrained detection model at batch-size 1 and image size 640:
|
|
||||||
yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
|
|
||||||
|
|
||||||
Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
|
|
||||||
yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
|
|
||||||
|
|
||||||
3. Run special commands:
|
|
||||||
|
|
||||||
yolo help
|
|
||||||
yolo checks
|
|
||||||
yolo version
|
|
||||||
yolo settings
|
|
||||||
yolo copy-config
|
|
||||||
|
|
||||||
Docs: https://docs.ultralytics.com/cli
|
|
||||||
Community: https://community.ultralytics.com
|
|
||||||
GitHub: https://github.com/ultralytics/ultralytics
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def cli(cfg):
|
|
||||||
"""
|
|
||||||
Run a specified task and mode with the given configuration.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cfg (DictConfig): Configuration for the task and mode.
|
|
||||||
"""
|
|
||||||
# LOGGER.info(f"{colorstr(f'Ultralytics YOLO v{ultralytics.__version__}')}")
|
|
||||||
from ultralytics.yolo.configs import get_config
|
|
||||||
|
|
||||||
if cfg.cfg:
|
|
||||||
LOGGER.info(f"{PREFIX}Overriding default config with {cfg.cfg}")
|
|
||||||
cfg = get_config(cfg.cfg)
|
|
||||||
task, mode = cfg.task.lower(), cfg.mode.lower()
|
|
||||||
|
|
||||||
# Mapping from task to module
|
|
||||||
tasks = {"detect": yolo.v8.detect, "segment": yolo.v8.segment, "classify": yolo.v8.classify}
|
|
||||||
module = tasks.get(task)
|
|
||||||
if not module:
|
|
||||||
raise SyntaxError(f"yolo task={task} is invalid. Valid tasks are: {', '.join(tasks.keys())}\n{CLI_HELP_MSG}")
|
|
||||||
|
|
||||||
# Mapping from mode to function
|
|
||||||
modes = {"train": module.train, "val": module.val, "predict": module.predict, "export": yolo.engine.exporter.export}
|
|
||||||
func = modes.get(mode)
|
|
||||||
if not func:
|
|
||||||
raise SyntaxError(f"yolo mode={mode} is invalid. Valid modes are: {', '.join(modes.keys())}\n{CLI_HELP_MSG}")
|
|
||||||
|
|
||||||
func(cfg)
|
|
||||||
|
|
||||||
|
|
||||||
def entrypoint():
|
|
||||||
"""
|
|
||||||
This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
|
|
||||||
to the package. It's a combination of argparse and hydra.
|
|
||||||
|
|
||||||
This function allows for:
|
|
||||||
- passing mandatory YOLO args as a list of strings
|
|
||||||
- specifying the task to be performed, either 'detect', 'segment' or 'classify'
|
|
||||||
- specifying the mode, either 'train', 'val', 'test', or 'predict'
|
|
||||||
- running special modes like 'checks'
|
|
||||||
- passing overrides to the package's configuration
|
|
||||||
|
|
||||||
It uses the package's default config and initializes it using the passed overrides.
|
|
||||||
Then it calls the CLI function with the composed config
|
|
||||||
"""
|
|
||||||
if len(sys.argv) == 1: # no arguments passed
|
|
||||||
LOGGER.info(CLI_HELP_MSG)
|
|
||||||
return
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='YOLO parser')
|
|
||||||
parser.add_argument('args', type=str, nargs='+', help='YOLO args')
|
|
||||||
args = parser.parse_args().args
|
|
||||||
args = re.sub(r'\s*=\s*', '=', ' '.join(args)).split(' ') # remove whitespaces around = sign
|
|
||||||
|
|
||||||
tasks = 'detect', 'segment', 'classify'
|
|
||||||
modes = 'train', 'val', 'predict', 'export'
|
|
||||||
special_modes = {
|
|
||||||
'help': lambda: LOGGER.info(CLI_HELP_MSG),
|
|
||||||
'checks': checks.check_yolo,
|
|
||||||
'version': lambda: LOGGER.info(__version__),
|
|
||||||
'settings': print_settings,
|
|
||||||
'copy-config': copy_default_config}
|
|
||||||
|
|
||||||
overrides = [] # basic overrides, i.e. imgsz=320
|
|
||||||
defaults = yaml_load(DEFAULT_CONFIG)
|
|
||||||
for a in args:
|
|
||||||
if '=' in a:
|
|
||||||
overrides.append(a)
|
|
||||||
elif a in tasks:
|
|
||||||
overrides.append(f'task={a}')
|
|
||||||
elif a in modes:
|
|
||||||
overrides.append(f'mode={a}')
|
|
||||||
elif a in special_modes:
|
|
||||||
special_modes[a]()
|
|
||||||
return
|
|
||||||
elif a in defaults and defaults[a] is False:
|
|
||||||
overrides.append(f'{a}=True') # auto-True for default False args, i.e. yolo show
|
|
||||||
elif a in defaults:
|
|
||||||
raise SyntaxError(f"'{a}' is a valid YOLO argument but is missing an '=' sign to set its value, "
|
|
||||||
f"i.e. try '{a}={defaults[a]}'"
|
|
||||||
f"\n{CLI_HELP_MSG}")
|
|
||||||
else:
|
|
||||||
raise SyntaxError(
|
|
||||||
f"'{a}' is not a valid YOLO argument. For a full list of valid arguments see "
|
|
||||||
f"https://github.com/ultralytics/ultralytics/blob/main/ultralytics/yolo/configs/default.yaml"
|
|
||||||
f"\n{CLI_HELP_MSG}")
|
|
||||||
|
|
||||||
from hydra import compose, initialize
|
|
||||||
|
|
||||||
with initialize(version_base=None, config_path=str(DEFAULT_CONFIG.parent.relative_to(DIR)), job_name="YOLO"):
|
|
||||||
cfg = compose(config_name=DEFAULT_CONFIG.name, overrides=overrides)
|
|
||||||
cli(cfg)
|
|
||||||
|
|
||||||
|
|
||||||
# Special modes --------------------------------------------------------------------------------------------------------
|
|
||||||
def copy_default_config():
|
|
||||||
new_file = Path.cwd() / DEFAULT_CONFIG.name.replace('.yaml', '_copy.yaml')
|
|
||||||
shutil.copy2(DEFAULT_CONFIG, new_file)
|
|
||||||
LOGGER.info(f"{PREFIX}{DEFAULT_CONFIG} copied to {new_file}\n"
|
|
||||||
f"Usage for running YOLO with this new custom config:\nyolo cfg={new_file} args...")
|
|
@ -1,36 +1,221 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
|
import argparse
|
||||||
|
import re
|
||||||
|
import shutil
|
||||||
|
import sys
|
||||||
|
from difflib import get_close_matches
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
from typing import Dict, Union
|
from typing import Dict, Union
|
||||||
|
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from ultralytics import __version__, yolo
|
||||||
|
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, PREFIX, checks, colorstr, print_settings, yaml_load
|
||||||
|
|
||||||
from ultralytics.yolo.configs.hydra_patch import check_config_mismatch
|
DIR = Path(__file__).parent
|
||||||
|
|
||||||
|
CLI_HELP_MSG = \
|
||||||
|
"""
|
||||||
|
YOLOv8 CLI Usage examples:
|
||||||
|
|
||||||
|
1. Install the ultralytics package:
|
||||||
|
|
||||||
|
pip install ultralytics
|
||||||
|
|
||||||
|
2. Train, Val, Predict and Export using 'yolo' commands:
|
||||||
|
|
||||||
|
yolo TASK MODE ARGS
|
||||||
|
|
||||||
|
Where TASK (optional) is one of [detect, segment, classify]
|
||||||
|
MODE (required) is one of [train, val, predict, export]
|
||||||
|
ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults.
|
||||||
|
For a full list of available ARGS see https://docs.ultralytics.com/config.
|
||||||
|
|
||||||
|
Train a detection model for 10 epochs with an initial learning_rate of 0.01
|
||||||
|
yolo detect train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
|
||||||
|
|
||||||
|
Predict a YouTube video using a pretrained segmentation model at image size 320:
|
||||||
|
yolo segment predict model=yolov8n-seg.pt source=https://youtu.be/Zgi9g1ksQHc imgsz=320
|
||||||
|
|
||||||
|
Validate a pretrained detection model at batch-size 1 and image size 640:
|
||||||
|
yolo detect val model=yolov8n.pt data=coco128.yaml batch=1 imgsz=640
|
||||||
|
|
||||||
|
Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
|
||||||
|
yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
|
||||||
|
|
||||||
|
3. Run special commands:
|
||||||
|
|
||||||
|
yolo help
|
||||||
|
yolo checks
|
||||||
|
yolo version
|
||||||
|
yolo settings
|
||||||
|
yolo copy-config
|
||||||
|
|
||||||
|
Docs: https://docs.ultralytics.com/cli
|
||||||
|
Community: https://community.ultralytics.com
|
||||||
|
GitHub: https://github.com/ultralytics/ultralytics
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_config(config: Union[str, Path, DictConfig], overrides: Union[str, Dict] = None):
|
def cfg2dict(cfg):
|
||||||
|
"""
|
||||||
|
Convert a configuration object to a dictionary.
|
||||||
|
|
||||||
|
This function converts a configuration object to a dictionary, whether it is a file path, a string, or a SimpleNamespace object.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
cfg (str) or (Path) or (SimpleNamespace): Configuration object to be converted to a dictionary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
cfg (dict): Configuration object in dictionary format.
|
||||||
|
"""
|
||||||
|
if isinstance(cfg, (str, Path)):
|
||||||
|
cfg = yaml_load(cfg) # load dict
|
||||||
|
elif isinstance(cfg, SimpleNamespace):
|
||||||
|
cfg = vars(cfg) # convert to dict
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
def get_config(config: Union[str, Path, Dict, SimpleNamespace], overrides: Dict = None):
|
||||||
"""
|
"""
|
||||||
Load and merge configuration data from a file or dictionary.
|
Load and merge configuration data from a file or dictionary.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config (str) or (Path) or (DictConfig): Configuration data in the form of a file name or a DictConfig object.
|
config (str) or (Path) or (Dict) or (SimpleNamespace): Configuration data.
|
||||||
overrides (str) or(Dict), optional: Overrides in the form of a file name or a dictionary. Default is None.
|
overrides (str) or (Dict), optional: Overrides in the form of a file name or a dictionary. Default is None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
OmegaConf.Namespace: Training arguments namespace.
|
(SimpleNamespace): Training arguments namespace.
|
||||||
"""
|
"""
|
||||||
if overrides is None:
|
config = cfg2dict(config)
|
||||||
overrides = {}
|
|
||||||
if isinstance(config, (str, Path)):
|
|
||||||
config = OmegaConf.load(config)
|
|
||||||
elif isinstance(config, Dict):
|
|
||||||
config = OmegaConf.create(config)
|
|
||||||
# override
|
|
||||||
if isinstance(overrides, str):
|
|
||||||
overrides = OmegaConf.load(overrides)
|
|
||||||
elif isinstance(overrides, Dict):
|
|
||||||
overrides = OmegaConf.create(overrides)
|
|
||||||
|
|
||||||
check_config_mismatch(dict(overrides).keys(), dict(config).keys())
|
# Merge overrides
|
||||||
|
if overrides:
|
||||||
|
overrides = cfg2dict(overrides)
|
||||||
|
check_config_mismatch(config, overrides)
|
||||||
|
config = {**config, **overrides} # merge config and overrides dicts (prefer overrides)
|
||||||
|
|
||||||
return OmegaConf.merge(config, overrides)
|
# Return instance
|
||||||
|
return SimpleNamespace(**config)
|
||||||
|
|
||||||
|
|
||||||
|
def check_config_mismatch(base: Dict, custom: Dict):
|
||||||
|
"""
|
||||||
|
This function checks for any mismatched keys between a custom configuration list and a base configuration list.
|
||||||
|
If any mismatched keys are found, the function prints out similar keys from the base list and exits the program.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- custom (Dict): a dictionary of custom configuration options
|
||||||
|
- base (Dict): a dictionary of base configuration options
|
||||||
|
"""
|
||||||
|
base, custom = (set(x.keys()) for x in (base, custom))
|
||||||
|
mismatched = [x for x in custom if x not in base]
|
||||||
|
for option in mismatched:
|
||||||
|
LOGGER.info(f"{colorstr(option)} is not a valid key. Similar keys: {get_close_matches(option, base, 3, 0.6)}")
|
||||||
|
if mismatched:
|
||||||
|
sys.exit()
|
||||||
|
|
||||||
|
|
||||||
|
def entrypoint(debug=True):
|
||||||
|
"""
|
||||||
|
This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
|
||||||
|
to the package.
|
||||||
|
|
||||||
|
This function allows for:
|
||||||
|
- passing mandatory YOLO args as a list of strings
|
||||||
|
- specifying the task to be performed, either 'detect', 'segment' or 'classify'
|
||||||
|
- specifying the mode, either 'train', 'val', 'test', or 'predict'
|
||||||
|
- running special modes like 'checks'
|
||||||
|
- passing overrides to the package's configuration
|
||||||
|
|
||||||
|
It uses the package's default config and initializes it using the passed overrides.
|
||||||
|
Then it calls the CLI function with the composed config
|
||||||
|
"""
|
||||||
|
if debug:
|
||||||
|
args = ['train', 'predict', 'model=yolov8n.pt'] # for testing
|
||||||
|
else:
|
||||||
|
if len(sys.argv) == 1: # no arguments passed
|
||||||
|
LOGGER.info(CLI_HELP_MSG)
|
||||||
|
return
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='YOLO parser')
|
||||||
|
parser.add_argument('args', type=str, nargs='+', help='YOLO args')
|
||||||
|
args = parser.parse_args().args
|
||||||
|
args = re.sub(r'\s*=\s*', '=', ' '.join(args)).split(' ') # remove whitespaces around = sign
|
||||||
|
|
||||||
|
tasks = 'detect', 'segment', 'classify'
|
||||||
|
modes = 'train', 'val', 'predict', 'export'
|
||||||
|
special_modes = {
|
||||||
|
'help': lambda: LOGGER.info(CLI_HELP_MSG),
|
||||||
|
'checks': checks.check_yolo,
|
||||||
|
'version': lambda: LOGGER.info(__version__),
|
||||||
|
'settings': print_settings,
|
||||||
|
'copy-config': copy_default_config}
|
||||||
|
|
||||||
|
overrides = {} # basic overrides, i.e. imgsz=320
|
||||||
|
defaults = yaml_load(DEFAULT_CFG_PATH)
|
||||||
|
for a in args:
|
||||||
|
if '=' in a:
|
||||||
|
if a.startswith('cfg='): # custom.yaml passed
|
||||||
|
custom_config = Path(a.split('=')[-1])
|
||||||
|
LOGGER.info(f"{PREFIX}Overriding {DEFAULT_CFG_PATH} with {custom_config}")
|
||||||
|
overrides = {k: v for k, v in yaml_load(custom_config).items() if k not in {'cfg'}}
|
||||||
|
else:
|
||||||
|
k, v = a.split('=')
|
||||||
|
try:
|
||||||
|
if k == 'device': # special DDP handling, i.e. device='0,1,2,3'
|
||||||
|
v = v.replace('[', '').replace(']', '') # handle device=[0,1,2,3]
|
||||||
|
v = v.replace(" ", "").replace('') # handle device=[0, 1, 2, 3]
|
||||||
|
v = v.replace('\\', '') # handle device=\'0,1,2,3\'
|
||||||
|
overrides[k] = v
|
||||||
|
else:
|
||||||
|
overrides[k] = eval(v) # convert strings to integers, floats, bools, etc.
|
||||||
|
except (NameError, SyntaxError):
|
||||||
|
overrides[k] = v
|
||||||
|
elif a in tasks:
|
||||||
|
overrides['task'] = a
|
||||||
|
elif a in modes:
|
||||||
|
overrides['mode'] = a
|
||||||
|
elif a in special_modes:
|
||||||
|
special_modes[a]()
|
||||||
|
return
|
||||||
|
elif a in defaults and defaults[a] is False:
|
||||||
|
overrides[a] = True # auto-True for default False args, i.e. 'yolo show' sets show=True
|
||||||
|
elif a in defaults:
|
||||||
|
raise SyntaxError(f"'{a}' is a valid YOLO argument but is missing an '=' sign to set its value, "
|
||||||
|
f"i.e. try '{a}={defaults[a]}'"
|
||||||
|
f"\n{CLI_HELP_MSG}")
|
||||||
|
else:
|
||||||
|
raise SyntaxError(
|
||||||
|
f"'{a}' is not a valid YOLO argument. For a full list of valid arguments see "
|
||||||
|
f"https://github.com/ultralytics/ultralytics/blob/main/ultralytics/yolo/configs/default.yaml"
|
||||||
|
f"\n{CLI_HELP_MSG}")
|
||||||
|
|
||||||
|
cfg = get_config(defaults, overrides) # create CFG instance
|
||||||
|
|
||||||
|
# Mapping from task to module
|
||||||
|
module = {"detect": yolo.v8.detect, "segment": yolo.v8.segment, "classify": yolo.v8.classify}.get(cfg.task)
|
||||||
|
if not module:
|
||||||
|
raise SyntaxError(f"yolo task={cfg.task} is invalid. Valid tasks are: {', '.join(tasks)}\n{CLI_HELP_MSG}")
|
||||||
|
|
||||||
|
# Mapping from mode to function
|
||||||
|
func = {
|
||||||
|
"train": module.train,
|
||||||
|
"val": module.val,
|
||||||
|
"predict": module.predict,
|
||||||
|
"export": yolo.engine.exporter.export}.get(cfg.mode)
|
||||||
|
if not func:
|
||||||
|
raise SyntaxError(f"yolo mode={cfg.mode} is invalid. Valid modes are: {', '.join(modes)}\n{CLI_HELP_MSG}")
|
||||||
|
|
||||||
|
func(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
# Special modes --------------------------------------------------------------------------------------------------------
|
||||||
|
def copy_default_config():
|
||||||
|
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
|
||||||
|
shutil.copy2(DEFAULT_CFG_PATH, new_file)
|
||||||
|
LOGGER.info(f"{PREFIX}{DEFAULT_CFG_PATH} copied to {new_file}\n"
|
||||||
|
f"Usage for running YOLO with this new custom config:\nyolo cfg={new_file} args...")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
entrypoint()
|
||||||
|
@ -1,68 +1,68 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
# Default training settings and hyperparameters for medium-augmentation COCO training
|
# Default training settings and hyperparameters for medium-augmentation COCO training
|
||||||
|
|
||||||
task: "detect" # choices=['detect', 'segment', 'classify', 'init'] # init is a special case. Specify task to run.
|
task: "detect" # choices=['detect', 'segment', 'classify', 'init'] # init is a special case. Specify task to run.
|
||||||
mode: "train" # choices=['train', 'val', 'predict'] # mode to run task in.
|
mode: "train" # choices=['train', 'val', 'predict'] # mode to run task in.
|
||||||
|
|
||||||
# Train settings -------------------------------------------------------------------------------------------------------
|
# Train settings -------------------------------------------------------------------------------------------------------
|
||||||
model: null # i.e. yolov8n.pt, yolov8n.yaml. Path to model file
|
model: null # i.e. yolov8n.pt, yolov8n.yaml. Path to model file
|
||||||
data: null # i.e. coco128.yaml. Path to data file
|
data: null # i.e. coco128.yaml. Path to data file
|
||||||
epochs: 100 # number of epochs to train for
|
epochs: 100 # number of epochs to train for
|
||||||
patience: 50 # epochs to wait for no observable improvement for early stopping of training
|
patience: 50 # epochs to wait for no observable improvement for early stopping of training
|
||||||
batch: 16 # number of images per batch
|
batch: 16 # number of images per batch
|
||||||
imgsz: 640 # size of input images
|
imgsz: 640 # size of input images
|
||||||
save: True # save checkpoints
|
save: True # save checkpoints
|
||||||
cache: False # True/ram, disk or False. Use cache for data loading
|
cache: False # True/ram, disk or False. Use cache for data loading
|
||||||
device: null # cuda device, i.e. 0 or 0,1,2,3 or cpu. Device to run on
|
device: null # cuda device, i.e. 0 or 0,1,2,3 or cpu. Device to run on
|
||||||
workers: 8 # number of worker threads for data loading
|
workers: 8 # number of worker threads for data loading
|
||||||
project: null # project name
|
project: null # project name
|
||||||
name: null # experiment name
|
name: null # experiment name
|
||||||
exist_ok: False # whether to overwrite existing experiment
|
exist_ok: False # whether to overwrite existing experiment
|
||||||
pretrained: False # whether to use a pretrained model
|
pretrained: False # whether to use a pretrained model
|
||||||
optimizer: 'SGD' # optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
|
optimizer: 'SGD' # optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
|
||||||
verbose: False # whether to print verbose output
|
verbose: False # whether to print verbose output
|
||||||
seed: 0 # random seed for reproducibility
|
seed: 0 # random seed for reproducibility
|
||||||
deterministic: True # whether to enable deterministic mode
|
deterministic: True # whether to enable deterministic mode
|
||||||
single_cls: False # train multi-class data as single-class
|
single_cls: False # train multi-class data as single-class
|
||||||
image_weights: False # use weighted image selection for training
|
image_weights: False # use weighted image selection for training
|
||||||
rect: False # support rectangular training
|
rect: False # support rectangular training
|
||||||
cos_lr: False # use cosine learning rate scheduler
|
cos_lr: False # use cosine learning rate scheduler
|
||||||
close_mosaic: 10 # disable mosaic augmentation for final 10 epochs
|
close_mosaic: 10 # disable mosaic augmentation for final 10 epochs
|
||||||
resume: False # resume training from last checkpoint
|
resume: False # resume training from last checkpoint
|
||||||
# Segmentation
|
# Segmentation
|
||||||
overlap_mask: True # masks should overlap during training
|
overlap_mask: True # masks should overlap during training
|
||||||
mask_ratio: 4 # mask downsample ratio
|
mask_ratio: 4 # mask downsample ratio
|
||||||
# Classification
|
# Classification
|
||||||
dropout: 0.0 # use dropout regularization
|
dropout: 0.0 # use dropout regularization
|
||||||
|
|
||||||
# Val/Test settings ----------------------------------------------------------------------------------------------------
|
# Val/Test settings ----------------------------------------------------------------------------------------------------
|
||||||
val: True # validate/test during training
|
val: True # validate/test during training
|
||||||
save_json: False # save results to JSON file
|
save_json: False # save results to JSON file
|
||||||
save_hybrid: False # save hybrid version of labels (labels + additional predictions)
|
save_hybrid: False # save hybrid version of labels (labels + additional predictions)
|
||||||
conf: null # object confidence threshold for detection (default 0.25 predict, 0.001 val)
|
conf: null # object confidence threshold for detection (default 0.25 predict, 0.001 val)
|
||||||
iou: 0.7 # intersection over union (IoU) threshold for NMS
|
iou: 0.7 # intersection over union (IoU) threshold for NMS
|
||||||
max_det: 300 # maximum number of detections per image
|
max_det: 300 # maximum number of detections per image
|
||||||
half: False # use half precision (FP16)
|
half: False # use half precision (FP16)
|
||||||
dnn: False # use OpenCV DNN for ONNX inference
|
dnn: False # use OpenCV DNN for ONNX inference
|
||||||
plots: True # show plots during training
|
plots: True # show plots during training
|
||||||
|
|
||||||
# Prediction settings --------------------------------------------------------------------------------------------------
|
# Prediction settings --------------------------------------------------------------------------------------------------
|
||||||
source: null # source directory for images or videos
|
source: null # source directory for images or videos
|
||||||
show: False # show results if possible
|
show: False # show results if possible
|
||||||
save_txt: False # save results as .txt file
|
save_txt: False # save results as .txt file
|
||||||
save_conf: False # save results with confidence scores
|
save_conf: False # save results with confidence scores
|
||||||
save_crop: False # save cropped images with results
|
save_crop: False # save cropped images with results
|
||||||
hide_labels: False # hide labels
|
hide_labels: False # hide labels
|
||||||
hide_conf: False # hide confidence scores
|
hide_conf: False # hide confidence scores
|
||||||
vid_stride: 1 # video frame-rate stride
|
vid_stride: 1 # video frame-rate stride
|
||||||
line_thickness: 3 # bounding box thickness (pixels)
|
line_thickness: 3 # bounding box thickness (pixels)
|
||||||
visualize: False # visualize results
|
visualize: False # visualize results
|
||||||
augment: False # apply data augmentation to images
|
augment: False # apply data augmentation to images
|
||||||
agnostic_nms: False # class-agnostic NMS
|
agnostic_nms: False # class-agnostic NMS
|
||||||
retina_masks: False # use retina masks for object detection
|
retina_masks: False # use retina masks for object detection
|
||||||
|
|
||||||
# Export settings ------------------------------------------------------------------------------------------------------
|
# Export settings ------------------------------------------------------------------------------------------------------
|
||||||
format: torchscript # format to export to
|
format: torchscript # format to export to
|
||||||
keras: False # use Keras
|
keras: False # use Keras
|
||||||
optimize: False # TorchScript: optimize for mobile
|
optimize: False # TorchScript: optimize for mobile
|
||||||
int8: False # CoreML/TF INT8 quantization
|
int8: False # CoreML/TF INT8 quantization
|
||||||
@ -100,12 +100,8 @@ mosaic: 1.0 # image mosaic (probability)
|
|||||||
mixup: 0.0 # image mixup (probability)
|
mixup: 0.0 # image mixup (probability)
|
||||||
copy_paste: 0.0 # segment copy-paste (probability)
|
copy_paste: 0.0 # segment copy-paste (probability)
|
||||||
|
|
||||||
# Hydra configs --------------------------------------------------------------------------------------------------------
|
# Custom config.yaml ---------------------------------------------------------------------------------------------------
|
||||||
cfg: null # for overriding defaults.yaml
|
cfg: null # for overriding defaults.yaml
|
||||||
hydra:
|
|
||||||
output_subdir: null # disable hydra directory creation
|
|
||||||
run:
|
|
||||||
dir: .
|
|
||||||
|
|
||||||
# Debug, do not modify -------------------------------------------------------------------------------------------------
|
# Debug, do not modify -------------------------------------------------------------------------------------------------
|
||||||
v5loader: False # use legacy YOLOv5 dataloader
|
v5loader: False # use legacy YOLOv5 dataloader
|
||||||
|
@ -1,77 +0,0 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from difflib import get_close_matches
|
|
||||||
from textwrap import dedent
|
|
||||||
|
|
||||||
import hydra
|
|
||||||
from hydra.errors import ConfigCompositionException
|
|
||||||
from omegaconf import OmegaConf, open_dict # noqa
|
|
||||||
from omegaconf.errors import ConfigAttributeError, ConfigKeyError, OmegaConfBaseException # noqa
|
|
||||||
|
|
||||||
from ultralytics.yolo.utils import LOGGER, colorstr
|
|
||||||
|
|
||||||
|
|
||||||
def override_config(overrides, cfg):
|
|
||||||
override_keys = [override.key_or_group for override in overrides]
|
|
||||||
check_config_mismatch(override_keys, cfg.keys())
|
|
||||||
for override in overrides:
|
|
||||||
if override.package is not None:
|
|
||||||
raise ConfigCompositionException(f"Override {override.input_line} looks like a config group"
|
|
||||||
f" override, but config group '{override.key_or_group}' does not exist.")
|
|
||||||
|
|
||||||
key = override.key_or_group
|
|
||||||
value = override.value()
|
|
||||||
try:
|
|
||||||
if override.is_delete():
|
|
||||||
config_val = OmegaConf.select(cfg, key, throw_on_missing=False)
|
|
||||||
if config_val is None:
|
|
||||||
raise ConfigCompositionException(f"Could not delete from config. '{override.key_or_group}'"
|
|
||||||
" does not exist.")
|
|
||||||
elif value is not None and value != config_val:
|
|
||||||
raise ConfigCompositionException("Could not delete from config. The value of"
|
|
||||||
f" '{override.key_or_group}' is {config_val} and not"
|
|
||||||
f" {value}.")
|
|
||||||
|
|
||||||
last_dot = key.rfind(".")
|
|
||||||
with open_dict(cfg):
|
|
||||||
if last_dot == -1:
|
|
||||||
del cfg[key]
|
|
||||||
else:
|
|
||||||
node = OmegaConf.select(cfg, key[:last_dot])
|
|
||||||
del node[key[last_dot + 1:]]
|
|
||||||
|
|
||||||
elif override.is_add():
|
|
||||||
if OmegaConf.select(cfg, key, throw_on_missing=False) is None or isinstance(value, (dict, list)):
|
|
||||||
OmegaConf.update(cfg, key, value, merge=True, force_add=True)
|
|
||||||
else:
|
|
||||||
assert override.input_line is not None
|
|
||||||
raise ConfigCompositionException(
|
|
||||||
dedent(f"""\
|
|
||||||
Could not append to config. An item is already at '{override.key_or_group}'.
|
|
||||||
Either remove + prefix: '{override.input_line[1:]}'
|
|
||||||
Or add a second + to add or override '{override.key_or_group}': '+{override.input_line}'
|
|
||||||
"""))
|
|
||||||
elif override.is_force_add():
|
|
||||||
OmegaConf.update(cfg, key, value, merge=True, force_add=True)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
OmegaConf.update(cfg, key, value, merge=True)
|
|
||||||
except (ConfigAttributeError, ConfigKeyError) as ex:
|
|
||||||
raise ConfigCompositionException(f"Could not override '{override.key_or_group}'."
|
|
||||||
f"\nTo append to your config use +{override.input_line}") from ex
|
|
||||||
except OmegaConfBaseException as ex:
|
|
||||||
raise ConfigCompositionException(f"Error merging override {override.input_line}").with_traceback(
|
|
||||||
sys.exc_info()[2]) from ex
|
|
||||||
|
|
||||||
|
|
||||||
def check_config_mismatch(overrides, cfg):
|
|
||||||
mismatched = [option for option in overrides if option not in cfg and 'hydra.' not in option]
|
|
||||||
|
|
||||||
for option in mismatched:
|
|
||||||
LOGGER.info(f"{colorstr(option)} is not a valid key. Similar keys: {get_close_matches(option, cfg, 3, 0.6)}")
|
|
||||||
if mismatched:
|
|
||||||
sys.exit()
|
|
||||||
|
|
||||||
|
|
||||||
hydra._internal.config_loader_impl.ConfigLoaderImpl._apply_overrides_to_config = override_config
|
|
@ -69,8 +69,8 @@ def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank
|
|||||||
augment=mode == "train", # augmentation
|
augment=mode == "train", # augmentation
|
||||||
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
|
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
|
||||||
rect=cfg.rect if mode == "train" else True, # rectangular batches
|
rect=cfg.rect if mode == "train" else True, # rectangular batches
|
||||||
cache=cfg.get("cache", None),
|
cache=cfg.cache or None,
|
||||||
single_cls=cfg.get("single_cls", False),
|
single_cls=cfg.single_cls or False,
|
||||||
stride=int(stride),
|
stride=int(stride),
|
||||||
pad=0.0 if mode == "train" else 0.5,
|
pad=0.0 if mode == "train" else 0.5,
|
||||||
prefix=colorstr(f"{mode}: "),
|
prefix=colorstr(f"{mode}: "),
|
||||||
|
@ -29,7 +29,8 @@ from torch.utils.data import DataLoader, Dataset, dataloader, distributed
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from ultralytics.yolo.data.utils import check_dataset, unzip_file
|
from ultralytics.yolo.data.utils import check_dataset, unzip_file
|
||||||
from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, is_colab, is_kaggle
|
from ultralytics.yolo.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, is_colab, is_dir_writeable,
|
||||||
|
is_kaggle)
|
||||||
from ultralytics.yolo.utils.checks import check_requirements, check_yaml
|
from ultralytics.yolo.utils.checks import check_requirements, check_yaml
|
||||||
from ultralytics.yolo.utils.ops import clean_str, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn
|
from ultralytics.yolo.utils.ops import clean_str, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn
|
||||||
from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first
|
from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first
|
||||||
@ -493,7 +494,7 @@ class LoadImagesAndLabels(Dataset):
|
|||||||
cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
|
cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
|
||||||
assert cache['version'] == self.cache_version # matches current version
|
assert cache['version'] == self.cache_version # matches current version
|
||||||
assert cache['hash'] == get_hash(self.label_files + self.im_files) # identical hash
|
assert cache['hash'] == get_hash(self.label_files + self.im_files) # identical hash
|
||||||
except (FileNotFoundError, AssertionError):
|
except (FileNotFoundError, AssertionError, AttributeError):
|
||||||
cache, exists = self.cache_labels(cache_path, prefix), False # run cache ops
|
cache, exists = self.cache_labels(cache_path, prefix), False # run cache ops
|
||||||
|
|
||||||
# Display cache
|
# Display cache
|
||||||
@ -579,16 +580,17 @@ class LoadImagesAndLabels(Dataset):
|
|||||||
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
|
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
|
||||||
self.im_hw0, self.im_hw = [None] * n, [None] * n
|
self.im_hw0, self.im_hw = [None] * n, [None] * n
|
||||||
fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
|
fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
|
||||||
results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
|
with (Pool if n > 10000 else ThreadPool)(NUM_THREADS) as pool:
|
||||||
pbar = tqdm(enumerate(results), total=n, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
|
results = pool.imap(fcn, range(n))
|
||||||
for i, x in pbar:
|
pbar = tqdm(enumerate(results), total=n, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
|
||||||
if cache_images == 'disk':
|
for i, x in pbar:
|
||||||
b += self.npy_files[i].stat().st_size
|
if cache_images == 'disk':
|
||||||
else: # 'ram'
|
b += self.npy_files[i].stat().st_size
|
||||||
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
|
else: # 'ram'
|
||||||
b += self.ims[i].nbytes
|
self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
|
||||||
pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})'
|
b += self.ims[i].nbytes
|
||||||
pbar.close()
|
pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})'
|
||||||
|
pbar.close()
|
||||||
|
|
||||||
def check_cache_ram(self, safety_margin=0.1, prefix=''):
|
def check_cache_ram(self, safety_margin=0.1, prefix=''):
|
||||||
# Check image caching requirements vs available memory
|
# Check image caching requirements vs available memory
|
||||||
@ -612,11 +614,10 @@ class LoadImagesAndLabels(Dataset):
|
|||||||
x = {} # dict
|
x = {} # dict
|
||||||
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
||||||
desc = f"{prefix}Scanning {path.parent / path.stem}..."
|
desc = f"{prefix}Scanning {path.parent / path.stem}..."
|
||||||
with Pool(NUM_THREADS) as pool:
|
total = len(self.im_files)
|
||||||
pbar = tqdm(pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix))),
|
with (Pool if total > 10000 else ThreadPool)(NUM_THREADS) as pool:
|
||||||
desc=desc,
|
results = pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix)))
|
||||||
total=len(self.im_files),
|
pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
|
||||||
bar_format=TQDM_BAR_FORMAT)
|
|
||||||
for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
|
for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
|
||||||
nm += nm_f
|
nm += nm_f
|
||||||
nf += nf_f
|
nf += nf_f
|
||||||
@ -627,8 +628,8 @@ class LoadImagesAndLabels(Dataset):
|
|||||||
if msg:
|
if msg:
|
||||||
msgs.append(msg)
|
msgs.append(msg)
|
||||||
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
||||||
|
pbar.close()
|
||||||
|
|
||||||
pbar.close()
|
|
||||||
if msgs:
|
if msgs:
|
||||||
LOGGER.info('\n'.join(msgs))
|
LOGGER.info('\n'.join(msgs))
|
||||||
if nf == 0:
|
if nf == 0:
|
||||||
@ -637,12 +638,12 @@ class LoadImagesAndLabels(Dataset):
|
|||||||
x['results'] = nf, nm, ne, nc, len(self.im_files)
|
x['results'] = nf, nm, ne, nc, len(self.im_files)
|
||||||
x['msgs'] = msgs # warnings
|
x['msgs'] = msgs # warnings
|
||||||
x['version'] = self.cache_version # cache version
|
x['version'] = self.cache_version # cache version
|
||||||
try:
|
if is_dir_writeable(path.parent):
|
||||||
np.save(path, x) # save cache for next time
|
np.save(str(path), x) # save cache for next time
|
||||||
path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
|
path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
|
||||||
LOGGER.info(f'{prefix}New cache created: {path}')
|
LOGGER.info(f'{prefix}New cache created: {path}')
|
||||||
except Exception as e:
|
else:
|
||||||
LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable: {e}') # not writeable
|
LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable') # not writeable
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
@ -1148,8 +1149,10 @@ class HUBDatasetStats():
|
|||||||
continue
|
continue
|
||||||
dataset = LoadImagesAndLabels(self.data[split]) # load dataset
|
dataset = LoadImagesAndLabels(self.data[split]) # load dataset
|
||||||
desc = f'{split} images'
|
desc = f'{split} images'
|
||||||
for _ in tqdm(ThreadPool(NUM_THREADS).imap(self._hub_ops, dataset.im_files), total=dataset.n, desc=desc):
|
total = dataset.n
|
||||||
pass
|
with (Pool if total > 10000 else ThreadPool)(NUM_THREADS) as pool:
|
||||||
|
for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=total, desc=desc):
|
||||||
|
pass
|
||||||
print(f'Done. All images saved to {self.im_dir}')
|
print(f'Done. All images saved to {self.im_dir}')
|
||||||
return self.im_dir
|
return self.im_dir
|
||||||
|
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
|
|
||||||
from itertools import repeat
|
from itertools import repeat
|
||||||
from multiprocessing.pool import Pool
|
from multiprocessing.pool import Pool, ThreadPool
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torchvision
|
import torchvision
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from ..utils import NUM_THREADS, TQDM_BAR_FORMAT
|
from ..utils import NUM_THREADS, TQDM_BAR_FORMAT, is_dir_writeable
|
||||||
from .augment import *
|
from .augment import *
|
||||||
from .base import BaseDataset
|
from .base import BaseDataset
|
||||||
from .utils import HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label
|
from .utils import HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label
|
||||||
@ -50,14 +50,12 @@ class YOLODataset(BaseDataset):
|
|||||||
x = {"labels": []}
|
x = {"labels": []}
|
||||||
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
||||||
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
|
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
|
||||||
with Pool(NUM_THREADS) as pool:
|
total = len(self.im_files)
|
||||||
pbar = tqdm(
|
with (Pool if total > 10000 else ThreadPool)(NUM_THREADS) as pool:
|
||||||
pool.imap(verify_image_label,
|
results = pool.imap(func=verify_image_label,
|
||||||
zip(self.im_files, self.label_files, repeat(self.prefix), repeat(self.use_keypoints))),
|
iterable=zip(self.im_files, self.label_files, repeat(self.prefix),
|
||||||
desc=desc,
|
repeat(self.use_keypoints)))
|
||||||
total=len(self.im_files),
|
pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
|
||||||
bar_format=TQDM_BAR_FORMAT,
|
|
||||||
)
|
|
||||||
for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
|
for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
|
||||||
nm += nm_f
|
nm += nm_f
|
||||||
nf += nf_f
|
nf += nf_f
|
||||||
@ -73,13 +71,12 @@ class YOLODataset(BaseDataset):
|
|||||||
segments=segments,
|
segments=segments,
|
||||||
keypoints=keypoint,
|
keypoints=keypoint,
|
||||||
normalized=True,
|
normalized=True,
|
||||||
bbox_format="xywh",
|
bbox_format="xywh"))
|
||||||
))
|
|
||||||
if msg:
|
if msg:
|
||||||
msgs.append(msg)
|
msgs.append(msg)
|
||||||
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
||||||
|
pbar.close()
|
||||||
|
|
||||||
pbar.close()
|
|
||||||
if msgs:
|
if msgs:
|
||||||
LOGGER.info("\n".join(msgs))
|
LOGGER.info("\n".join(msgs))
|
||||||
if nf == 0:
|
if nf == 0:
|
||||||
@ -89,13 +86,12 @@ class YOLODataset(BaseDataset):
|
|||||||
x["msgs"] = msgs # warnings
|
x["msgs"] = msgs # warnings
|
||||||
x["version"] = self.cache_version # cache version
|
x["version"] = self.cache_version # cache version
|
||||||
self.im_files = [lb["im_file"] for lb in x["labels"]]
|
self.im_files = [lb["im_file"] for lb in x["labels"]]
|
||||||
try:
|
if is_dir_writeable(path.parent):
|
||||||
np.save(path, x) # save cache for next time
|
np.save(str(path), x) # save cache for next time
|
||||||
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
|
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
|
||||||
LOGGER.info(f"{self.prefix}New cache created: {path}")
|
LOGGER.info(f"{self.prefix}New cache created: {path}")
|
||||||
except Exception as e:
|
else:
|
||||||
LOGGER.warning(
|
LOGGER.warning(f"{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable") # not writeable
|
||||||
f"{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable: {e}") # not writeable
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def get_labels(self):
|
def get_labels(self):
|
||||||
@ -105,7 +101,7 @@ class YOLODataset(BaseDataset):
|
|||||||
cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True # load dict
|
cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True # load dict
|
||||||
assert cache["version"] == self.cache_version # matches current version
|
assert cache["version"] == self.cache_version # matches current version
|
||||||
assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash
|
assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash
|
||||||
except (FileNotFoundError, AssertionError):
|
except (FileNotFoundError, AssertionError, AttributeError):
|
||||||
cache, exists = self.cache_labels(cache_path), False # run cache ops
|
cache, exists = self.cache_labels(cache_path), False # run cache ops
|
||||||
|
|
||||||
# Display cache
|
# Display cache
|
||||||
|
@ -99,7 +99,7 @@ names:
|
|||||||
|
|
||||||
# Download script/URL (optional)
|
# Download script/URL (optional)
|
||||||
download: |
|
download: |
|
||||||
from ultralytics.yoloutils.downloads import download
|
from ultralytics.yolo.utils.downloads import download
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
# Download labels
|
# Download labels
|
||||||
|
@ -60,7 +60,6 @@ from collections import defaultdict
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import hydra
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
@ -71,7 +70,7 @@ from ultralytics.nn.tasks import ClassificationModel, DetectionModel, Segmentati
|
|||||||
from ultralytics.yolo.configs import get_config
|
from ultralytics.yolo.configs import get_config
|
||||||
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
|
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
|
||||||
from ultralytics.yolo.data.utils import check_dataset
|
from ultralytics.yolo.data.utils import check_dataset
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, callbacks, colorstr, get_default_args, yaml_save
|
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, get_default_args, yaml_save
|
||||||
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
|
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
|
||||||
from ultralytics.yolo.utils.files import file_size
|
from ultralytics.yolo.utils.files import file_size
|
||||||
from ultralytics.yolo.utils.ops import Profile
|
from ultralytics.yolo.utils.ops import Profile
|
||||||
@ -123,11 +122,11 @@ class Exporter:
|
|||||||
A class for exporting a model.
|
A class for exporting a model.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
args (OmegaConf): Configuration for the exporter.
|
args (SimpleNamespace): Configuration for the exporter.
|
||||||
save_dir (Path): Directory to save results.
|
save_dir (Path): Directory to save results.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
|
def __init__(self, config=DEFAULT_CFG, overrides=None):
|
||||||
"""
|
"""
|
||||||
Initializes the Exporter class.
|
Initializes the Exporter class.
|
||||||
|
|
||||||
@ -135,8 +134,6 @@ class Exporter:
|
|||||||
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
||||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||||
"""
|
"""
|
||||||
if overrides is None:
|
|
||||||
overrides = {}
|
|
||||||
self.args = get_config(config, overrides)
|
self.args = get_config(config, overrides)
|
||||||
self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks
|
self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks
|
||||||
callbacks.add_integration_callbacks(self)
|
callbacks.add_integration_callbacks(self)
|
||||||
@ -799,8 +796,7 @@ class Exporter:
|
|||||||
callback(self)
|
callback(self)
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
def export(cfg=DEFAULT_CFG):
|
||||||
def export(cfg):
|
|
||||||
cfg.model = cfg.model or "yolov8n.yaml"
|
cfg.model = cfg.model or "yolov8n.yaml"
|
||||||
cfg.format = cfg.format or "torchscript"
|
cfg.format = cfg.format or "torchscript"
|
||||||
|
|
||||||
@ -818,7 +814,7 @@ def export(cfg):
|
|||||||
|
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
model = YOLO(cfg.model)
|
model = YOLO(cfg.model)
|
||||||
model.export(**cfg)
|
model.export(**vars(cfg))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -6,7 +6,7 @@ from ultralytics import yolo # noqa
|
|||||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight
|
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight
|
||||||
from ultralytics.yolo.configs import get_config
|
from ultralytics.yolo.configs import get_config
|
||||||
from ultralytics.yolo.engine.exporter import Exporter
|
from ultralytics.yolo.engine.exporter import Exporter
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, yaml_load
|
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, yaml_load
|
||||||
from ultralytics.yolo.utils.checks import check_yaml
|
from ultralytics.yolo.utils.checks import check_yaml
|
||||||
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode
|
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode
|
||||||
|
|
||||||
@ -151,7 +151,7 @@ class YOLO:
|
|||||||
overrides = self.overrides.copy()
|
overrides = self.overrides.copy()
|
||||||
overrides.update(kwargs)
|
overrides.update(kwargs)
|
||||||
overrides["mode"] = "val"
|
overrides["mode"] = "val"
|
||||||
args = get_config(config=DEFAULT_CONFIG, overrides=overrides)
|
args = get_config(config=DEFAULT_CFG_PATH, overrides=overrides)
|
||||||
args.data = data or args.data
|
args.data = data or args.data
|
||||||
args.task = self.task
|
args.task = self.task
|
||||||
|
|
||||||
@ -169,7 +169,7 @@ class YOLO:
|
|||||||
|
|
||||||
overrides = self.overrides.copy()
|
overrides = self.overrides.copy()
|
||||||
overrides.update(kwargs)
|
overrides.update(kwargs)
|
||||||
args = get_config(config=DEFAULT_CONFIG, overrides=overrides)
|
args = get_config(config=DEFAULT_CFG_PATH, overrides=overrides)
|
||||||
args.task = self.task
|
args.task = self.task
|
||||||
|
|
||||||
print(args)
|
print(args)
|
||||||
|
@ -36,7 +36,7 @@ from ultralytics.nn.autobackend import AutoBackend
|
|||||||
from ultralytics.yolo.configs import get_config
|
from ultralytics.yolo.configs import get_config
|
||||||
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams
|
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams
|
||||||
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, SETTINGS, callbacks, colorstr, ops
|
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, SETTINGS, callbacks, colorstr, ops
|
||||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
|
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
|
||||||
from ultralytics.yolo.utils.files import increment_path
|
from ultralytics.yolo.utils.files import increment_path
|
||||||
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
|
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
|
||||||
@ -49,7 +49,7 @@ class BasePredictor:
|
|||||||
A base class for creating predictors.
|
A base class for creating predictors.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
args (OmegaConf): Configuration for the predictor.
|
args (SimpleNamespace): Configuration for the predictor.
|
||||||
save_dir (Path): Directory to save results.
|
save_dir (Path): Directory to save results.
|
||||||
done_setup (bool): Whether the predictor has finished setup.
|
done_setup (bool): Whether the predictor has finished setup.
|
||||||
model (nn.Module): Model used for prediction.
|
model (nn.Module): Model used for prediction.
|
||||||
@ -62,7 +62,7 @@ class BasePredictor:
|
|||||||
data_path (str): Path to data.
|
data_path (str): Path to data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
|
def __init__(self, config=DEFAULT_CFG_PATH, overrides=None):
|
||||||
"""
|
"""
|
||||||
Initializes the BasePredictor class.
|
Initializes the BasePredictor class.
|
||||||
|
|
||||||
@ -70,8 +70,6 @@ class BasePredictor:
|
|||||||
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
||||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||||
"""
|
"""
|
||||||
if overrides is None:
|
|
||||||
overrides = {}
|
|
||||||
self.args = get_config(config, overrides)
|
self.args = get_config(config, overrides)
|
||||||
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
||||||
name = self.args.name or f"{self.args.mode}"
|
name = self.args.name or f"{self.args.mode}"
|
||||||
@ -157,7 +155,7 @@ class BasePredictor:
|
|||||||
if stream:
|
if stream:
|
||||||
return self.stream_inference(source, model, verbose)
|
return self.stream_inference(source, model, verbose)
|
||||||
else:
|
else:
|
||||||
return list(chain(*list(self.stream_inference(source, model, verbose)))) # merge list of Result into one
|
return list(self.stream_inference(source, model, verbose)) # merge list of Result into one
|
||||||
|
|
||||||
def predict_cli(self):
|
def predict_cli(self):
|
||||||
# Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode
|
# Method used for CLI prediction. It uses always generator as outputs as not required by CLI mode
|
||||||
@ -211,7 +209,7 @@ class BasePredictor:
|
|||||||
if self.args.save:
|
if self.args.save:
|
||||||
self.save_preds(vid_cap, i, str(self.save_dir / p.name))
|
self.save_preds(vid_cap, i, str(self.save_dir / p.name))
|
||||||
|
|
||||||
yield results
|
yield from results
|
||||||
|
|
||||||
# Print time (inference-only)
|
# Print time (inference-only)
|
||||||
if verbose:
|
if verbose:
|
||||||
|
@ -15,8 +15,6 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from omegaconf import OmegaConf # noqa
|
|
||||||
from omegaconf import open_dict
|
|
||||||
from torch.cuda import amp
|
from torch.cuda import amp
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.optim import lr_scheduler
|
from torch.optim import lr_scheduler
|
||||||
@ -27,7 +25,7 @@ from ultralytics import __version__
|
|||||||
from ultralytics.nn.tasks import attempt_load_one_weight
|
from ultralytics.nn.tasks import attempt_load_one_weight
|
||||||
from ultralytics.yolo.configs import get_config
|
from ultralytics.yolo.configs import get_config
|
||||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||||
from ultralytics.yolo.utils import (DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr,
|
from ultralytics.yolo.utils import (DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr,
|
||||||
yaml_save)
|
yaml_save)
|
||||||
from ultralytics.yolo.utils.autobatch import check_train_batch_size
|
from ultralytics.yolo.utils.autobatch import check_train_batch_size
|
||||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args
|
from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args
|
||||||
@ -43,7 +41,7 @@ class BaseTrainer:
|
|||||||
A base class for creating trainers.
|
A base class for creating trainers.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
args (OmegaConf): Configuration for the trainer.
|
args (SimpleNamespace): Configuration for the trainer.
|
||||||
check_resume (method): Method to check if training should be resumed from a saved checkpoint.
|
check_resume (method): Method to check if training should be resumed from a saved checkpoint.
|
||||||
console (logging.Logger): Logger instance.
|
console (logging.Logger): Logger instance.
|
||||||
validator (BaseValidator): Validator instance.
|
validator (BaseValidator): Validator instance.
|
||||||
@ -73,7 +71,7 @@ class BaseTrainer:
|
|||||||
csv (Path): Path to results CSV file.
|
csv (Path): Path to results CSV file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
|
def __init__(self, config=DEFAULT_CFG_PATH, overrides=None):
|
||||||
"""
|
"""
|
||||||
Initializes the BaseTrainer class.
|
Initializes the BaseTrainer class.
|
||||||
|
|
||||||
@ -81,8 +79,6 @@ class BaseTrainer:
|
|||||||
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
||||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||||
"""
|
"""
|
||||||
if overrides is None:
|
|
||||||
overrides = {}
|
|
||||||
self.args = get_config(config, overrides)
|
self.args = get_config(config, overrides)
|
||||||
self.device = utils.torch_utils.select_device(self.args.device, self.args.batch)
|
self.device = utils.torch_utils.select_device(self.args.device, self.args.batch)
|
||||||
self.check_resume()
|
self.check_resume()
|
||||||
@ -95,23 +91,23 @@ class BaseTrainer:
|
|||||||
# Dirs
|
# Dirs
|
||||||
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
||||||
name = self.args.name or f"{self.args.mode}"
|
name = self.args.name or f"{self.args.mode}"
|
||||||
self.save_dir = Path(
|
if hasattr(self.args, 'save_dir'):
|
||||||
self.args.get(
|
self.save_dir = Path(self.args.save_dir)
|
||||||
"save_dir",
|
else:
|
||||||
increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in {-1, 0} else True)))
|
self.save_dir = Path(
|
||||||
|
increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in {-1, 0} else True))
|
||||||
self.wdir = self.save_dir / 'weights' # weights dir
|
self.wdir = self.save_dir / 'weights' # weights dir
|
||||||
if RANK in {-1, 0}:
|
if RANK in {-1, 0}:
|
||||||
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
|
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
|
||||||
with open_dict(self.args):
|
self.args.save_dir = str(self.save_dir)
|
||||||
self.args.save_dir = str(self.save_dir)
|
yaml_save(self.save_dir / 'args.yaml', vars(self.args)) # save run args
|
||||||
yaml_save(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True)) # save run args
|
|
||||||
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
|
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
|
||||||
|
|
||||||
self.batch_size = self.args.batch
|
self.batch_size = self.args.batch
|
||||||
self.epochs = self.args.epochs
|
self.epochs = self.args.epochs
|
||||||
self.start_epoch = 0
|
self.start_epoch = 0
|
||||||
if RANK == -1:
|
if RANK == -1:
|
||||||
print_args(dict(self.args))
|
print_args(vars(self.args))
|
||||||
|
|
||||||
# Device
|
# Device
|
||||||
self.amp = self.device.type != 'cpu'
|
self.amp = self.device.type != 'cpu'
|
||||||
@ -373,7 +369,7 @@ class BaseTrainer:
|
|||||||
'ema': deepcopy(self.ema.ema).half(),
|
'ema': deepcopy(self.ema.ema).half(),
|
||||||
'updates': self.ema.updates,
|
'updates': self.ema.updates,
|
||||||
'optimizer': self.optimizer.state_dict(),
|
'optimizer': self.optimizer.state_dict(),
|
||||||
'train_args': self.args,
|
'train_args': vars(self.args), # save as dict
|
||||||
'date': datetime.now().isoformat(),
|
'date': datetime.now().isoformat(),
|
||||||
'version': __version__}
|
'version': __version__}
|
||||||
|
|
||||||
|
@ -5,12 +5,12 @@ from collections import defaultdict
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from omegaconf import OmegaConf # noqa
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from ultralytics.nn.autobackend import AutoBackend
|
from ultralytics.nn.autobackend import AutoBackend
|
||||||
|
from ultralytics.yolo.configs import get_config
|
||||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks
|
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks
|
||||||
from ultralytics.yolo.utils.checks import check_imgsz
|
from ultralytics.yolo.utils.checks import check_imgsz
|
||||||
from ultralytics.yolo.utils.files import increment_path
|
from ultralytics.yolo.utils.files import increment_path
|
||||||
from ultralytics.yolo.utils.ops import Profile
|
from ultralytics.yolo.utils.ops import Profile
|
||||||
@ -27,7 +27,7 @@ class BaseValidator:
|
|||||||
dataloader (DataLoader): Dataloader to use for validation.
|
dataloader (DataLoader): Dataloader to use for validation.
|
||||||
pbar (tqdm): Progress bar to update during validation.
|
pbar (tqdm): Progress bar to update during validation.
|
||||||
logger (logging.Logger): Logger to use for validation.
|
logger (logging.Logger): Logger to use for validation.
|
||||||
args (OmegaConf): Configuration for the validator.
|
args (SimpleNamespace): Configuration for the validator.
|
||||||
model (nn.Module): Model to validate.
|
model (nn.Module): Model to validate.
|
||||||
data (dict): Data dictionary.
|
data (dict): Data dictionary.
|
||||||
device (torch.device): Device to use for validation.
|
device (torch.device): Device to use for validation.
|
||||||
@ -47,12 +47,12 @@ class BaseValidator:
|
|||||||
save_dir (Path): Directory to save results.
|
save_dir (Path): Directory to save results.
|
||||||
pbar (tqdm.tqdm): Progress bar for displaying progress.
|
pbar (tqdm.tqdm): Progress bar for displaying progress.
|
||||||
logger (logging.Logger): Logger to log messages.
|
logger (logging.Logger): Logger to log messages.
|
||||||
args (OmegaConf): Configuration for the validator.
|
args (SimpleNamespace): Configuration for the validator.
|
||||||
"""
|
"""
|
||||||
self.dataloader = dataloader
|
self.dataloader = dataloader
|
||||||
self.pbar = pbar
|
self.pbar = pbar
|
||||||
self.logger = logger or LOGGER
|
self.logger = logger or LOGGER
|
||||||
self.args = args or OmegaConf.load(DEFAULT_CONFIG)
|
self.args = args or get_config(DEFAULT_CFG_PATH)
|
||||||
self.model = None
|
self.model = None
|
||||||
self.data = None
|
self.data = None
|
||||||
self.device = None
|
self.device = None
|
||||||
|
@ -8,6 +8,7 @@ import platform
|
|||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import threading
|
import threading
|
||||||
|
import types
|
||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
@ -22,7 +23,7 @@ import yaml
|
|||||||
# Constants
|
# Constants
|
||||||
FILE = Path(__file__).resolve()
|
FILE = Path(__file__).resolve()
|
||||||
ROOT = FILE.parents[2] # YOLO
|
ROOT = FILE.parents[2] # YOLO
|
||||||
DEFAULT_CONFIG = ROOT / "yolo/configs/default.yaml"
|
DEFAULT_CFG_PATH = ROOT / "yolo/configs/default.yaml"
|
||||||
RANK = int(os.getenv('RANK', -1))
|
RANK = int(os.getenv('RANK', -1))
|
||||||
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
|
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
|
||||||
AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
|
AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
|
||||||
@ -73,9 +74,10 @@ os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
|
|||||||
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # for deterministic training
|
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # for deterministic training
|
||||||
|
|
||||||
# Default config dictionary
|
# Default config dictionary
|
||||||
with open(DEFAULT_CONFIG, errors='ignore') as f:
|
with open(DEFAULT_CFG_PATH, errors='ignore') as f:
|
||||||
DEFAULT_CONFIG_DICT = yaml.safe_load(f)
|
DEFAULT_CFG_DICT = yaml.safe_load(f)
|
||||||
DEFAULT_CONFIG_KEYS = DEFAULT_CONFIG_DICT.keys()
|
DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
|
||||||
|
DEFAULT_CFG = types.SimpleNamespace(**DEFAULT_CFG_DICT)
|
||||||
|
|
||||||
|
|
||||||
def is_colab():
|
def is_colab():
|
||||||
|
@ -28,7 +28,7 @@ def generate_ddp_file(trainer):
|
|||||||
|
|
||||||
if not trainer.resume:
|
if not trainer.resume:
|
||||||
shutil.rmtree(trainer.save_dir) # remove the save_dir
|
shutil.rmtree(trainer.save_dir) # remove the save_dir
|
||||||
content = f'''config = {dict(trainer.args)} \nif __name__ == "__main__":
|
content = f'''config = {vars(trainer.args)} \nif __name__ == "__main__":
|
||||||
from ultralytics.{import_path} import {trainer.__class__.__name__}
|
from ultralytics.{import_path} import {trainer.__class__.__name__}
|
||||||
|
|
||||||
trainer = {trainer.__class__.__name__}(config=config)
|
trainer = {trainer.__class__.__name__}(config=config)
|
||||||
|
@ -18,7 +18,7 @@ import torch.nn.functional as F
|
|||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
import ultralytics
|
import ultralytics
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER
|
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER
|
||||||
from ultralytics.yolo.utils.checks import git_describe
|
from ultralytics.yolo.utils.checks import git_describe
|
||||||
|
|
||||||
from .checks import check_version
|
from .checks import check_version
|
||||||
@ -288,7 +288,7 @@ def strip_optimizer(f='best.pt', s=''):
|
|||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
x = torch.load(f, map_location=torch.device('cpu'))
|
x = torch.load(f, map_location=torch.device('cpu'))
|
||||||
args = {**DEFAULT_CONFIG_DICT, **x['train_args']} # combine model args with default args, preferring model args
|
args = {**DEFAULT_CFG_DICT, **x['train_args']} # combine model args with default args, preferring model args
|
||||||
if x.get('ema'):
|
if x.get('ema'):
|
||||||
x['model'] = x['ema'] # replace model with ema
|
x['model'] = x['ema'] # replace model with ema
|
||||||
for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
|
for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
|
||||||
@ -297,7 +297,8 @@ def strip_optimizer(f='best.pt', s=''):
|
|||||||
x['model'].half() # to FP16
|
x['model'].half() # to FP16
|
||||||
for p in x['model'].parameters():
|
for p in x['model'].parameters():
|
||||||
p.requires_grad = False
|
p.requires_grad = False
|
||||||
x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} # strip non-default keys
|
x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
|
||||||
|
# x['model'].args = x['train_args']
|
||||||
torch.save(x, s or f)
|
torch.save(x, s or f)
|
||||||
mb = os.path.getsize(s or f) / 1E6 # filesize
|
mb = os.path.getsize(s or f) / 1E6 # filesize
|
||||||
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
|
|
||||||
from ultralytics.yolo.configs import hydra_patch # noqa (patch hydra CLI)
|
|
||||||
from ultralytics.yolo.v8 import classify, detect, segment
|
from ultralytics.yolo.v8 import classify, detect, segment
|
||||||
|
|
||||||
__all__ = ["classify", "segment", "detect"]
|
__all__ = ["classify", "segment", "detect"]
|
||||||
|
@ -1,11 +1,10 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
|
|
||||||
import hydra
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ultralytics.yolo.engine.predictor import BasePredictor
|
from ultralytics.yolo.engine.predictor import BasePredictor
|
||||||
from ultralytics.yolo.engine.results import Results
|
from ultralytics.yolo.engine.results import Results
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, is_git_directory
|
from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, is_git_directory
|
||||||
from ultralytics.yolo.utils.plotting import Annotator
|
from ultralytics.yolo.utils.plotting import Annotator
|
||||||
|
|
||||||
|
|
||||||
@ -64,8 +63,7 @@ class ClassificationPredictor(BasePredictor):
|
|||||||
return log_string
|
return log_string
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
def predict(cfg=DEFAULT_CFG):
|
||||||
def predict(cfg):
|
|
||||||
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
|
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
|
||||||
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \
|
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \
|
||||||
else "https://ultralytics.com/images/bus.jpg"
|
else "https://ultralytics.com/images/bus.jpg"
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
|
|
||||||
import hydra
|
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
|
|
||||||
@ -8,13 +7,13 @@ from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight
|
|||||||
from ultralytics.yolo import v8
|
from ultralytics.yolo import v8
|
||||||
from ultralytics.yolo.data import build_classification_dataloader
|
from ultralytics.yolo.data import build_classification_dataloader
|
||||||
from ultralytics.yolo.engine.trainer import BaseTrainer
|
from ultralytics.yolo.engine.trainer import BaseTrainer
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG
|
from ultralytics.yolo.utils import DEFAULT_CFG
|
||||||
from ultralytics.yolo.utils.torch_utils import strip_optimizer
|
from ultralytics.yolo.utils.torch_utils import strip_optimizer
|
||||||
|
|
||||||
|
|
||||||
class ClassificationTrainer(BaseTrainer):
|
class ClassificationTrainer(BaseTrainer):
|
||||||
|
|
||||||
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
|
def __init__(self, config=DEFAULT_CFG, overrides=None):
|
||||||
if overrides is None:
|
if overrides is None:
|
||||||
overrides = {}
|
overrides = {}
|
||||||
overrides["task"] = "classify"
|
overrides["task"] = "classify"
|
||||||
@ -136,8 +135,7 @@ class ClassificationTrainer(BaseTrainer):
|
|||||||
# self.run_callbacks('on_fit_epoch_end')
|
# self.run_callbacks('on_fit_epoch_end')
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
def train(cfg=DEFAULT_CFG):
|
||||||
def train(cfg):
|
|
||||||
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
|
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
|
||||||
cfg.data = cfg.data or "mnist160" # or yolo.ClassificationDataset("mnist")
|
cfg.data = cfg.data or "mnist160" # or yolo.ClassificationDataset("mnist")
|
||||||
|
|
||||||
@ -152,7 +150,7 @@ def train(cfg):
|
|||||||
# trainer.train()
|
# trainer.train()
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
model = YOLO(cfg.model)
|
model = YOLO(cfg.model)
|
||||||
model.train(**cfg)
|
model.train(**vars(cfg))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
|
|
||||||
import hydra
|
|
||||||
|
|
||||||
from ultralytics.yolo.data import build_classification_dataloader
|
from ultralytics.yolo.data import build_classification_dataloader
|
||||||
from ultralytics.yolo.engine.validator import BaseValidator
|
from ultralytics.yolo.engine.validator import BaseValidator
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG
|
from ultralytics.yolo.utils import DEFAULT_CFG
|
||||||
from ultralytics.yolo.utils.metrics import ClassifyMetrics
|
from ultralytics.yolo.utils.metrics import ClassifyMetrics
|
||||||
|
|
||||||
|
|
||||||
@ -46,8 +44,7 @@ class ClassificationValidator(BaseValidator):
|
|||||||
self.logger.info(pf % ("all", self.metrics.top1, self.metrics.top5))
|
self.logger.info(pf % ("all", self.metrics.top1, self.metrics.top5))
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
def val(cfg=DEFAULT_CFG):
|
||||||
def val(cfg):
|
|
||||||
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
|
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
|
||||||
cfg.data = cfg.data or "imagenette160"
|
cfg.data = cfg.data or "imagenette160"
|
||||||
validator = ClassificationValidator(args=cfg)
|
validator = ClassificationValidator(args=cfg)
|
||||||
|
@ -1,11 +1,10 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
|
|
||||||
import hydra
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ultralytics.yolo.engine.predictor import BasePredictor
|
from ultralytics.yolo.engine.predictor import BasePredictor
|
||||||
from ultralytics.yolo.engine.results import Results
|
from ultralytics.yolo.engine.results import Results
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, is_git_directory, ops
|
from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, is_git_directory, ops
|
||||||
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
|
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
|
||||||
|
|
||||||
|
|
||||||
@ -81,8 +80,7 @@ class DetectionPredictor(BasePredictor):
|
|||||||
return log_string
|
return log_string
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
def predict(cfg=DEFAULT_CFG):
|
||||||
def predict(cfg):
|
|
||||||
cfg.model = cfg.model or "yolov8n.pt"
|
cfg.model = cfg.model or "yolov8n.pt"
|
||||||
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \
|
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \
|
||||||
else "https://ultralytics.com/images/bus.jpg"
|
else "https://ultralytics.com/images/bus.jpg"
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from copy import copy
|
from copy import copy
|
||||||
|
|
||||||
import hydra
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
@ -11,7 +10,7 @@ from ultralytics.yolo import v8
|
|||||||
from ultralytics.yolo.data import build_dataloader
|
from ultralytics.yolo.data import build_dataloader
|
||||||
from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
|
from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
|
||||||
from ultralytics.yolo.engine.trainer import BaseTrainer
|
from ultralytics.yolo.engine.trainer import BaseTrainer
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, colorstr
|
from ultralytics.yolo.utils import DEFAULT_CFG, colorstr
|
||||||
from ultralytics.yolo.utils.loss import BboxLoss
|
from ultralytics.yolo.utils.loss import BboxLoss
|
||||||
from ultralytics.yolo.utils.ops import xywh2xyxy
|
from ultralytics.yolo.utils.ops import xywh2xyxy
|
||||||
from ultralytics.yolo.utils.plotting import plot_images, plot_results
|
from ultralytics.yolo.utils.plotting import plot_images, plot_results
|
||||||
@ -30,7 +29,7 @@ class DetectionTrainer(BaseTrainer):
|
|||||||
imgsz=self.args.imgsz,
|
imgsz=self.args.imgsz,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
stride=gs,
|
stride=gs,
|
||||||
hyp=dict(self.args),
|
hyp=vars(self.args),
|
||||||
augment=mode == "train",
|
augment=mode == "train",
|
||||||
cache=self.args.cache,
|
cache=self.args.cache,
|
||||||
pad=0 if mode == "train" else 0.5,
|
pad=0 if mode == "train" else 0.5,
|
||||||
@ -195,8 +194,7 @@ class Loss:
|
|||||||
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
|
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
def train(cfg=DEFAULT_CFG):
|
||||||
def train(cfg):
|
|
||||||
cfg.model = cfg.model or "yolov8n.pt"
|
cfg.model = cfg.model or "yolov8n.pt"
|
||||||
cfg.data = cfg.data or "coco128.yaml" # or yolo.ClassificationDataset("mnist")
|
cfg.data = cfg.data or "coco128.yaml" # or yolo.ClassificationDataset("mnist")
|
||||||
cfg.device = cfg.device if cfg.device is not None else ''
|
cfg.device = cfg.device if cfg.device is not None else ''
|
||||||
@ -204,7 +202,7 @@ def train(cfg):
|
|||||||
# trainer.train()
|
# trainer.train()
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
model = YOLO(cfg.model)
|
model = YOLO(cfg.model)
|
||||||
model.train(**cfg)
|
model.train(**vars(cfg))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -3,14 +3,13 @@
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import hydra
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ultralytics.yolo.data import build_dataloader
|
from ultralytics.yolo.data import build_dataloader
|
||||||
from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
|
from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
|
||||||
from ultralytics.yolo.engine.validator import BaseValidator
|
from ultralytics.yolo.engine.validator import BaseValidator
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, colorstr, ops, yaml_load
|
from ultralytics.yolo.utils import DEFAULT_CFG, colorstr, ops, yaml_load
|
||||||
from ultralytics.yolo.utils.checks import check_file, check_requirements
|
from ultralytics.yolo.utils.checks import check_file, check_requirements
|
||||||
from ultralytics.yolo.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
|
from ultralytics.yolo.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
|
||||||
from ultralytics.yolo.utils.plotting import output_to_target, plot_images
|
from ultralytics.yolo.utils.plotting import output_to_target, plot_images
|
||||||
@ -168,7 +167,7 @@ class DetectionValidator(BaseValidator):
|
|||||||
imgsz=self.args.imgsz,
|
imgsz=self.args.imgsz,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
stride=gs,
|
stride=gs,
|
||||||
hyp=dict(self.args),
|
hyp=vars(self.args),
|
||||||
cache=False,
|
cache=False,
|
||||||
pad=0.5,
|
pad=0.5,
|
||||||
rect=True,
|
rect=True,
|
||||||
@ -232,8 +231,7 @@ class DetectionValidator(BaseValidator):
|
|||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
def val(cfg=DEFAULT_CFG):
|
||||||
def val(cfg):
|
|
||||||
cfg.model = cfg.model or "yolov8n.pt"
|
cfg.model = cfg.model or "yolov8n.pt"
|
||||||
cfg.data = cfg.data or "coco128.yaml"
|
cfg.data = cfg.data or "coco128.yaml"
|
||||||
validator = DetectionValidator(args=cfg)
|
validator = DetectionValidator(args=cfg)
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
|
|
||||||
import hydra
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ultralytics.yolo.engine.results import Results
|
from ultralytics.yolo.engine.results import Results
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, is_git_directory, ops
|
from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, is_git_directory, ops
|
||||||
from ultralytics.yolo.utils.plotting import colors, save_one_box
|
from ultralytics.yolo.utils.plotting import colors, save_one_box
|
||||||
from ultralytics.yolo.v8.detect.predict import DetectionPredictor
|
from ultralytics.yolo.v8.detect.predict import DetectionPredictor
|
||||||
|
|
||||||
@ -98,8 +97,7 @@ class SegmentationPredictor(DetectionPredictor):
|
|||||||
return log_string
|
return log_string
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
def predict(cfg=DEFAULT_CFG):
|
||||||
def predict(cfg):
|
|
||||||
cfg.model = cfg.model or "yolov8n-seg.pt"
|
cfg.model = cfg.model or "yolov8n-seg.pt"
|
||||||
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \
|
cfg.source = cfg.source if cfg.source is not None else ROOT / "assets" if is_git_directory() \
|
||||||
else "https://ultralytics.com/images/bus.jpg"
|
else "https://ultralytics.com/images/bus.jpg"
|
||||||
|
@ -2,13 +2,12 @@
|
|||||||
|
|
||||||
from copy import copy
|
from copy import copy
|
||||||
|
|
||||||
import hydra
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from ultralytics.nn.tasks import SegmentationModel
|
from ultralytics.nn.tasks import SegmentationModel
|
||||||
from ultralytics.yolo import v8
|
from ultralytics.yolo import v8
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG
|
from ultralytics.yolo.utils import DEFAULT_CFG
|
||||||
from ultralytics.yolo.utils.ops import crop_mask, xyxy2xywh
|
from ultralytics.yolo.utils.ops import crop_mask, xyxy2xywh
|
||||||
from ultralytics.yolo.utils.plotting import plot_images, plot_results
|
from ultralytics.yolo.utils.plotting import plot_images, plot_results
|
||||||
from ultralytics.yolo.utils.tal import make_anchors
|
from ultralytics.yolo.utils.tal import make_anchors
|
||||||
@ -19,7 +18,7 @@ from ultralytics.yolo.v8.detect.train import Loss
|
|||||||
# BaseTrainer python usage
|
# BaseTrainer python usage
|
||||||
class SegmentationTrainer(v8.detect.DetectionTrainer):
|
class SegmentationTrainer(v8.detect.DetectionTrainer):
|
||||||
|
|
||||||
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
|
def __init__(self, config=DEFAULT_CFG, overrides=None):
|
||||||
if overrides is None:
|
if overrides is None:
|
||||||
overrides = {}
|
overrides = {}
|
||||||
overrides["task"] = "segment"
|
overrides["task"] = "segment"
|
||||||
@ -141,8 +140,7 @@ class SegLoss(Loss):
|
|||||||
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
|
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
def train(cfg=DEFAULT_CFG):
|
||||||
def train(cfg):
|
|
||||||
cfg.model = cfg.model or "yolov8n-seg.pt"
|
cfg.model = cfg.model or "yolov8n-seg.pt"
|
||||||
cfg.data = cfg.data or "coco128-seg.yaml" # or yolo.ClassificationDataset("mnist")
|
cfg.data = cfg.data or "coco128-seg.yaml" # or yolo.ClassificationDataset("mnist")
|
||||||
cfg.device = cfg.device if cfg.device is not None else ''
|
cfg.device = cfg.device if cfg.device is not None else ''
|
||||||
@ -150,7 +148,7 @@ def train(cfg):
|
|||||||
# trainer.train()
|
# trainer.train()
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
model = YOLO(cfg.model)
|
model = YOLO(cfg.model)
|
||||||
model.train(**cfg)
|
model.train(**vars(cfg))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -4,12 +4,11 @@ import os
|
|||||||
from multiprocessing.pool import ThreadPool
|
from multiprocessing.pool import ThreadPool
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import hydra
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, NUM_THREADS, ops
|
from ultralytics.yolo.utils import DEFAULT_CFG, NUM_THREADS, ops
|
||||||
from ultralytics.yolo.utils.checks import check_requirements
|
from ultralytics.yolo.utils.checks import check_requirements
|
||||||
from ultralytics.yolo.utils.metrics import ConfusionMatrix, SegmentMetrics, box_iou, mask_iou
|
from ultralytics.yolo.utils.metrics import ConfusionMatrix, SegmentMetrics, box_iou, mask_iou
|
||||||
from ultralytics.yolo.utils.plotting import output_to_target, plot_images
|
from ultralytics.yolo.utils.plotting import output_to_target, plot_images
|
||||||
@ -243,8 +242,7 @@ class SegmentationValidator(DetectionValidator):
|
|||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
def val(cfg=DEFAULT_CFG):
|
||||||
def val(cfg):
|
|
||||||
cfg.data = cfg.data or "coco128-seg.yaml"
|
cfg.data = cfg.data or "coco128-seg.yaml"
|
||||||
validator = SegmentationValidator(args=cfg)
|
validator = SegmentationValidator(args=cfg)
|
||||||
validator(model=cfg.model)
|
validator(model=cfg.model)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user