mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-24 06:14:55 +08:00
ultralytics 8.0.19
seg/det dataset warning and DDP-cls/seg fixes (#595)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: 曾逸夫(Zeng Yifu) <41098760+Zengyf-CVer@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
parent
936414c615
commit
520825c4b2
@ -14,7 +14,7 @@ ci:
|
|||||||
|
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v4.3.0
|
rev: v4.4.0
|
||||||
hooks:
|
hooks:
|
||||||
# - id: end-of-file-fixer
|
# - id: end-of-file-fixer
|
||||||
- id: trailing-whitespace
|
- id: trailing-whitespace
|
||||||
@ -25,14 +25,14 @@ repos:
|
|||||||
- id: check-docstring-first
|
- id: check-docstring-first
|
||||||
|
|
||||||
- repo: https://github.com/asottile/pyupgrade
|
- repo: https://github.com/asottile/pyupgrade
|
||||||
rev: v2.37.3
|
rev: v3.3.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: pyupgrade
|
- id: pyupgrade
|
||||||
name: Upgrade code
|
name: Upgrade code
|
||||||
args: [ --py37-plus ]
|
args: [ --py37-plus ]
|
||||||
|
|
||||||
- repo: https://github.com/PyCQA/isort
|
- repo: https://github.com/PyCQA/isort
|
||||||
rev: 5.10.1
|
rev: 5.11.4
|
||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
name: Sort imports
|
name: Sort imports
|
||||||
@ -59,6 +59,13 @@ repos:
|
|||||||
- id: flake8
|
- id: flake8
|
||||||
name: PEP8
|
name: PEP8
|
||||||
|
|
||||||
|
- repo: https://github.com/codespell-project/codespell
|
||||||
|
rev: v2.2.2
|
||||||
|
hooks:
|
||||||
|
- id: codespell
|
||||||
|
args:
|
||||||
|
- --ignore-words-list=crate,nd
|
||||||
|
|
||||||
#- repo: https://github.com/asottile/yesqa
|
#- repo: https://github.com/asottile/yesqa
|
||||||
# rev: v1.4.0
|
# rev: v1.4.0
|
||||||
# hooks:
|
# hooks:
|
||||||
|
@ -183,7 +183,7 @@ Default arguments can be overriden by simply passing them as arguments in the CL
|
|||||||
You can override the `default.yaml` config file entirely by passing a new file with the `cfg` arguments,
|
You can override the `default.yaml` config file entirely by passing a new file with the `cfg` arguments,
|
||||||
i.e. `cfg=custom.yaml`.
|
i.e. `cfg=custom.yaml`.
|
||||||
|
|
||||||
To do this first create a copy of `default.yaml` in your current working dir with the `yolo copy-config` command.
|
To do this first create a copy of `default.yaml` in your current working dir with the `yolo copy-cfg` command.
|
||||||
|
|
||||||
This will create `default_copy.yaml`, which you can then pass as `cfg=default_copy.yaml` along with any additional args,
|
This will create `default_copy.yaml`, which you can then pass as `cfg=default_copy.yaml` along with any additional args,
|
||||||
like `imgsz=320` in this example:
|
like `imgsz=320` in this example:
|
||||||
@ -192,6 +192,6 @@ like `imgsz=320` in this example:
|
|||||||
|
|
||||||
=== "CLI"
|
=== "CLI"
|
||||||
```bash
|
```bash
|
||||||
yolo copy-config
|
yolo copy-cfg
|
||||||
yolo cfg=default_copy.yaml imgsz=320
|
yolo cfg=default_copy.yaml imgsz=320
|
||||||
```
|
```
|
@ -638,11 +638,11 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"source": [
|
"source": [
|
||||||
"# Load YOLOv8n-cls, train it on imagenette160 for 3 epochs and predict an image with it\n",
|
"# Load YOLOv8n-cls, train it on mnist160 for 3 epochs and predict an image with it\n",
|
||||||
"from ultralytics import YOLO\n",
|
"from ultralytics import YOLO\n",
|
||||||
"\n",
|
"\n",
|
||||||
"model = YOLO('yolov8n-cls.pt') # load a pretrained YOLOv8n classification model\n",
|
"model = YOLO('yolov8n-cls.pt') # load a pretrained YOLOv8n classification model\n",
|
||||||
"model.train(data='imagenette160', epochs=3) # train the model\n",
|
"model.train(data='mnist160', epochs=3) # train the model\n",
|
||||||
"model('https://ultralytics.com/images/bus.jpg') # predict on an image"
|
"model('https://ultralytics.com/images/bus.jpg') # predict on an image"
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
@ -3,13 +3,13 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from ultralytics.yolo.cfg import get_cfg
|
from ultralytics.yolo.cfg import get_cfg
|
||||||
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, ROOT, SETTINGS
|
from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, SETTINGS
|
||||||
from ultralytics.yolo.v8 import classify, detect, segment
|
from ultralytics.yolo.v8 import classify, detect, segment
|
||||||
|
|
||||||
CFG_DET = 'yolov8n.yaml'
|
CFG_DET = 'yolov8n.yaml'
|
||||||
CFG_SEG = 'yolov8n-seg.yaml'
|
CFG_SEG = 'yolov8n-seg.yaml'
|
||||||
CFG_CLS = 'squeezenet1_0'
|
CFG_CLS = 'squeezenet1_0'
|
||||||
CFG = get_cfg(DEFAULT_CFG_PATH)
|
CFG = get_cfg(DEFAULT_CFG)
|
||||||
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n'
|
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n'
|
||||||
SOURCE = ROOT / "assets"
|
SOURCE = ROOT / "assets"
|
||||||
|
|
||||||
|
@ -313,13 +313,39 @@ class ClassificationModel(BaseModel):
|
|||||||
# Functions ------------------------------------------------------------------------------------------------------------
|
# Functions ------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def torch_safe_load(weight):
|
||||||
|
"""
|
||||||
|
This function attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it
|
||||||
|
catches the error, logs a warning message, and attempts to install the missing module via the check_requirements()
|
||||||
|
function. After installation, the function again attempts to load the model using torch.load().
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weight (str): The file path of the PyTorch model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The loaded PyTorch model.
|
||||||
|
"""
|
||||||
|
from ultralytics.yolo.utils.downloads import attempt_download
|
||||||
|
|
||||||
|
file = attempt_download(weight) # search online if missing locally
|
||||||
|
try:
|
||||||
|
return torch.load(file, map_location='cpu') # load
|
||||||
|
except ModuleNotFoundError as e:
|
||||||
|
if e.name == 'omegaconf': # e.name is missing module name
|
||||||
|
LOGGER.warning(f"WARNING ⚠️ {weight} requires {e.name}, which is not in ultralytics requirements."
|
||||||
|
f"\nAutoInstall will run now for {e.name} but this feature will be removed in the future."
|
||||||
|
f"\nRecommend fixes are to train a new model using updated ultraltyics package or to "
|
||||||
|
f"download updated models from https://github.com/ultralytics/assets/releases/tag/v0.0.0")
|
||||||
|
check_requirements(e.name) # install missing module
|
||||||
|
return torch.load(file, map_location='cpu') # load
|
||||||
|
|
||||||
|
|
||||||
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
||||||
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
||||||
from ultralytics.yolo.utils.downloads import attempt_download
|
|
||||||
|
|
||||||
model = Ensemble()
|
model = Ensemble()
|
||||||
for w in weights if isinstance(weights, list) else [weights]:
|
for w in weights if isinstance(weights, list) else [weights]:
|
||||||
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
|
ckpt = torch_safe_load(w) # load ckpt
|
||||||
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
||||||
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||||
|
|
||||||
@ -355,18 +381,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
|||||||
|
|
||||||
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
||||||
# Loads a single model weights
|
# Loads a single model weights
|
||||||
from ultralytics.yolo.utils.downloads import attempt_download
|
ckpt = torch_safe_load(weight) # load ckpt
|
||||||
|
|
||||||
weight = attempt_download(weight)
|
|
||||||
try:
|
|
||||||
ckpt = torch.load(weight, map_location='cpu') # load
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
LOGGER.warning(f"WARNING ⚠️ {weight} is deprecated as it requires omegaconf, which is now removed from "
|
|
||||||
"ultralytics requirements.\nAutoInstall will occur now but this feature will be removed for "
|
|
||||||
"omegaconf models in the future.\nPlease train a new model or download updated models "
|
|
||||||
"from https://github.com/ultralytics/assets/releases/tag/v0.0.0")
|
|
||||||
check_requirements('omegaconf')
|
|
||||||
ckpt = torch.load(weight, map_location='cpu') # load
|
|
||||||
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
||||||
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||||
|
|
||||||
|
@ -611,6 +611,8 @@ class LoadImagesAndLabels(Dataset):
|
|||||||
|
|
||||||
def cache_labels(self, path=Path('./labels.cache'), prefix=''):
|
def cache_labels(self, path=Path('./labels.cache'), prefix=''):
|
||||||
# Cache dataset labels, check images and read shapes
|
# Cache dataset labels, check images and read shapes
|
||||||
|
if path.exists():
|
||||||
|
path.unlink() # remove *.cache file if exists
|
||||||
x = {} # dict
|
x = {} # dict
|
||||||
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
||||||
desc = f"{prefix}Scanning {path.parent / path.stem}..."
|
desc = f"{prefix}Scanning {path.parent / path.stem}..."
|
||||||
|
@ -47,6 +47,8 @@ class YOLODataset(BaseDataset):
|
|||||||
|
|
||||||
def cache_labels(self, path=Path("./labels.cache")):
|
def cache_labels(self, path=Path("./labels.cache")):
|
||||||
# Cache dataset labels, check images and read shapes
|
# Cache dataset labels, check images and read shapes
|
||||||
|
if path.exists():
|
||||||
|
path.unlink() # remove *.cache file if exists
|
||||||
x = {"labels": []}
|
x = {"labels": []}
|
||||||
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
||||||
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
|
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
|
||||||
@ -85,7 +87,7 @@ class YOLODataset(BaseDataset):
|
|||||||
x["results"] = nf, nm, ne, nc, len(self.im_files)
|
x["results"] = nf, nm, ne, nc, len(self.im_files)
|
||||||
x["msgs"] = msgs # warnings
|
x["msgs"] = msgs # warnings
|
||||||
x["version"] = self.cache_version # cache version
|
x["version"] = self.cache_version # cache version
|
||||||
self.im_files = [lb["im_file"] for lb in x["labels"]]
|
self.im_files = [lb["im_file"] for lb in x["labels"]] # update im_files
|
||||||
if is_dir_writeable(path.parent):
|
if is_dir_writeable(path.parent):
|
||||||
np.save(str(path), x) # save cache for next time
|
np.save(str(path), x) # save cache for next time
|
||||||
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
|
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
|
||||||
@ -116,6 +118,17 @@ class YOLODataset(BaseDataset):
|
|||||||
# Read cache
|
# Read cache
|
||||||
[cache.pop(k) for k in ("hash", "version", "msgs")] # remove items
|
[cache.pop(k) for k in ("hash", "version", "msgs")] # remove items
|
||||||
labels = cache["labels"]
|
labels = cache["labels"]
|
||||||
|
|
||||||
|
# Check if the dataset is all boxes or all segments
|
||||||
|
len_boxes = sum(len(lb["bboxes"]) for lb in labels)
|
||||||
|
len_segments = sum(len(lb["segments"]) for lb in labels)
|
||||||
|
if len_segments and len_boxes != len_segments:
|
||||||
|
LOGGER.warning(
|
||||||
|
f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
|
||||||
|
f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
|
||||||
|
"To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.")
|
||||||
|
for lb in labels:
|
||||||
|
lb["segments"] = []
|
||||||
nl = len(np.concatenate([label["cls"] for label in labels], 0)) # number of labels
|
nl = len(np.concatenate([label["cls"] for label in labels], 0)) # number of labels
|
||||||
assert nl > 0, f"{self.prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}"
|
assert nl > 0, f"{self.prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}"
|
||||||
return labels
|
return labels
|
||||||
|
@ -14,7 +14,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from PIL import ExifTags, Image, ImageOps
|
from PIL import ExifTags, Image, ImageOps
|
||||||
|
|
||||||
from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, ROOT, colorstr, yaml_load
|
from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, ROOT, colorstr, emojis, yaml_load
|
||||||
from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii
|
from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii
|
||||||
from ultralytics.yolo.utils.downloads import download
|
from ultralytics.yolo.utils.downloads import download
|
||||||
from ultralytics.yolo.utils.files import unzip_file
|
from ultralytics.yolo.utils.files import unzip_file
|
||||||
@ -202,7 +202,10 @@ def check_det_dataset(dataset, autodownload=True):
|
|||||||
|
|
||||||
# Checks
|
# Checks
|
||||||
for k in 'train', 'val', 'names':
|
for k in 'train', 'val', 'names':
|
||||||
assert k in data, f"data.yaml '{k}:' field missing ❌"
|
if k not in data:
|
||||||
|
raise SyntaxError(
|
||||||
|
emojis(f"{dataset} '{k}:' key missing ❌.\n"
|
||||||
|
f"'train', 'val' and 'names' are required in data.yaml files."))
|
||||||
if isinstance(data['names'], (list, tuple)): # old array format
|
if isinstance(data['names'], (list, tuple)): # old array format
|
||||||
data['names'] = dict(enumerate(data['names'])) # convert to dict
|
data['names'] = dict(enumerate(data['names'])) # convert to dict
|
||||||
data['nc'] = len(data['names'])
|
data['nc'] = len(data['names'])
|
||||||
|
@ -388,7 +388,7 @@ class Exporter:
|
|||||||
@try_export
|
@try_export
|
||||||
def _export_engine(self, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
|
def _export_engine(self, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
|
||||||
# YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt
|
# YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt
|
||||||
assert self.im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `device==0`'
|
assert self.im.device.type != 'cpu', "export running on CPU but must be on GPU, i.e. use 'device=0'"
|
||||||
try:
|
try:
|
||||||
import tensorrt as trt # noqa
|
import tensorrt as trt # noqa
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
@ -53,7 +53,12 @@ class YOLO:
|
|||||||
self.overrides = {} # overrides for trainer object
|
self.overrides = {} # overrides for trainer object
|
||||||
|
|
||||||
# Load or create new YOLO model
|
# Load or create new YOLO model
|
||||||
{'.pt': self._load, '.yaml': self._new}[Path(model).suffix](model)
|
load_methods = {'.pt': self._load, '.yaml': self._new}
|
||||||
|
suffix = Path(model).suffix
|
||||||
|
if suffix in load_methods:
|
||||||
|
{'.pt': self._load, '.yaml': self._new}[suffix](model)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"'{suffix}' model loading not implemented")
|
||||||
|
|
||||||
def __call__(self, source=None, stream=False, verbose=False, **kwargs):
|
def __call__(self, source=None, stream=False, verbose=False, **kwargs):
|
||||||
return self.predict(source, stream, verbose, **kwargs)
|
return self.predict(source, stream, verbose, **kwargs)
|
||||||
|
@ -35,7 +35,7 @@ from ultralytics.nn.autobackend import AutoBackend
|
|||||||
from ultralytics.yolo.cfg import get_cfg
|
from ultralytics.yolo.cfg import get_cfg
|
||||||
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams
|
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams
|
||||||
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
||||||
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, SETTINGS, callbacks, colorstr, ops
|
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, SETTINGS, callbacks, colorstr, ops
|
||||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
|
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
|
||||||
from ultralytics.yolo.utils.files import increment_path
|
from ultralytics.yolo.utils.files import increment_path
|
||||||
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
|
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
|
||||||
@ -61,12 +61,12 @@ class BasePredictor:
|
|||||||
data_path (str): Path to data.
|
data_path (str): Path to data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, cfg=DEFAULT_CFG_PATH, overrides=None):
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
|
||||||
"""
|
"""
|
||||||
Initializes the BasePredictor class.
|
Initializes the BasePredictor class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
|
||||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||||
"""
|
"""
|
||||||
self.args = get_cfg(cfg, overrides)
|
self.args = get_cfg(cfg, overrides)
|
||||||
|
@ -24,8 +24,8 @@ from ultralytics import __version__
|
|||||||
from ultralytics.nn.tasks import attempt_load_one_weight
|
from ultralytics.nn.tasks import attempt_load_one_weight
|
||||||
from ultralytics.yolo.cfg import get_cfg
|
from ultralytics.yolo.cfg import get_cfg
|
||||||
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
|
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
|
||||||
from ultralytics.yolo.utils import (DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr,
|
from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis,
|
||||||
emojis, yaml_save)
|
yaml_save)
|
||||||
from ultralytics.yolo.utils.autobatch import check_train_batch_size
|
from ultralytics.yolo.utils.autobatch import check_train_batch_size
|
||||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args
|
from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args
|
||||||
from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
|
from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
|
||||||
@ -71,12 +71,12 @@ class BaseTrainer:
|
|||||||
csv (Path): Path to results CSV file.
|
csv (Path): Path to results CSV file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, cfg=DEFAULT_CFG_PATH, overrides=None):
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
|
||||||
"""
|
"""
|
||||||
Initializes the BaseTrainer class.
|
Initializes the BaseTrainer class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
|
||||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||||
"""
|
"""
|
||||||
self.args = get_cfg(cfg, overrides)
|
self.args = get_cfg(cfg, overrides)
|
||||||
|
@ -10,7 +10,7 @@ from tqdm import tqdm
|
|||||||
from ultralytics.nn.autobackend import AutoBackend
|
from ultralytics.nn.autobackend import AutoBackend
|
||||||
from ultralytics.yolo.cfg import get_cfg
|
from ultralytics.yolo.cfg import get_cfg
|
||||||
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
|
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
|
||||||
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, emojis
|
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, emojis
|
||||||
from ultralytics.yolo.utils.checks import check_imgsz
|
from ultralytics.yolo.utils.checks import check_imgsz
|
||||||
from ultralytics.yolo.utils.files import increment_path
|
from ultralytics.yolo.utils.files import increment_path
|
||||||
from ultralytics.yolo.utils.ops import Profile
|
from ultralytics.yolo.utils.ops import Profile
|
||||||
@ -52,7 +52,7 @@ class BaseValidator:
|
|||||||
self.dataloader = dataloader
|
self.dataloader = dataloader
|
||||||
self.pbar = pbar
|
self.pbar = pbar
|
||||||
self.logger = logger or LOGGER
|
self.logger = logger or LOGGER
|
||||||
self.args = args or get_cfg(DEFAULT_CFG_PATH)
|
self.args = args or get_cfg(DEFAULT_CFG)
|
||||||
self.model = None
|
self.model = None
|
||||||
self.data = None
|
self.data = None
|
||||||
self.device = None
|
self.device = None
|
||||||
|
@ -127,8 +127,7 @@ def is_colab():
|
|||||||
Returns:
|
Returns:
|
||||||
bool: True if running inside a Colab notebook, False otherwise.
|
bool: True if running inside a Colab notebook, False otherwise.
|
||||||
"""
|
"""
|
||||||
# Check if the 'google.colab' module is present in sys.modules
|
return 'COLAB_RELEASE_TAG' in os.environ or 'COLAB_BACKEND_VERSION' in os.environ
|
||||||
return 'google.colab' in sys.modules
|
|
||||||
|
|
||||||
|
|
||||||
def is_kaggle():
|
def is_kaggle():
|
||||||
|
@ -224,7 +224,7 @@ def check_file(file, suffix=''):
|
|||||||
for d in 'models', 'yolo/data': # search directories
|
for d in 'models', 'yolo/data': # search directories
|
||||||
files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
|
files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
|
||||||
if not files:
|
if not files:
|
||||||
raise FileNotFoundError(f"{file} does not exist")
|
raise FileNotFoundError(f"'{file}' does not exist")
|
||||||
elif len(files) > 1:
|
elif len(files) > 1:
|
||||||
raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
|
raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
|
||||||
return files[0] # return file
|
return files[0] # return file
|
||||||
|
@ -10,17 +10,14 @@ from . import USER_CONFIG_DIR
|
|||||||
|
|
||||||
|
|
||||||
def find_free_network_port() -> int:
|
def find_free_network_port() -> int:
|
||||||
# https://github.com/Lightning-AI/lightning/blob/master/src/lightning_lite/plugins/environments/lightning.py
|
|
||||||
"""Finds a free port on localhost.
|
"""Finds a free port on localhost.
|
||||||
|
|
||||||
It is useful in single-node training when we don't want to connect to a real main node but have to set the
|
It is useful in single-node training when we don't want to connect to a real main node but have to set the
|
||||||
`MASTER_PORT` environment variable.
|
`MASTER_PORT` environment variable.
|
||||||
"""
|
"""
|
||||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
s.bind(("", 0))
|
s.bind(('127.0.0.1', 0))
|
||||||
port = s.getsockname()[1]
|
return s.getsockname()[1] # port
|
||||||
s.close()
|
|
||||||
return port
|
|
||||||
|
|
||||||
|
|
||||||
def generate_ddp_file(trainer):
|
def generate_ddp_file(trainer):
|
||||||
|
@ -91,12 +91,10 @@ def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'):
|
|||||||
|
|
||||||
file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
|
file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
|
||||||
if name in assets:
|
if name in assets:
|
||||||
url3 = 'https://drive.google.com/drive/folders/1EFQTEUeXWSFww0luse2jB9M1QNZQGwNl' # backup gdrive mirror
|
safe_download(file,
|
||||||
safe_download(
|
url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
|
||||||
file,
|
min_bytes=1E5,
|
||||||
url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
|
error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag}')
|
||||||
min_bytes=1E5,
|
|
||||||
error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag} or {url3}')
|
|
||||||
|
|
||||||
return str(file)
|
return str(file)
|
||||||
|
|
||||||
|
@ -58,7 +58,7 @@ def DDP_model(model):
|
|||||||
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
|
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
|
||||||
|
|
||||||
|
|
||||||
def select_device(device='', batch_size=0, newline=False):
|
def select_device(device='', batch=0, newline=False):
|
||||||
# device = None or 'cpu' or 0 or '0' or '0,1,2,3'
|
# device = None or 'cpu' or 0 or '0' or '0,1,2,3'
|
||||||
ver = git_describe() or ultralytics.__version__ # git commit or pip package version
|
ver = git_describe() or ultralytics.__version__ # git commit or pip package version
|
||||||
s = f'Ultralytics YOLOv{ver} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
|
s = f'Ultralytics YOLOv{ver} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
|
||||||
@ -71,14 +71,15 @@ def select_device(device='', batch_size=0, newline=False):
|
|||||||
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
|
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
|
||||||
elif device: # non-cpu device requested
|
elif device: # non-cpu device requested
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
|
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
|
||||||
assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
|
if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', ''))):
|
||||||
f"Invalid CUDA 'device={device}' requested, use 'device=cpu' or pass valid CUDA device(s)"
|
raise ValueError(f"Invalid CUDA 'device={device}' requested, use 'device=cpu' or pass valid CUDA device(s)")
|
||||||
|
|
||||||
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
|
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
|
||||||
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
|
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
|
||||||
n = len(devices) # device count
|
n = len(devices) # device count
|
||||||
if n > 1 and batch_size > 0: # check batch_size is divisible by device_count
|
if n > 1 and batch > 0 and batch % n != 0: # check batch_size is divisible by device_count
|
||||||
assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
|
raise ValueError(f'batch={batch} is not multiple of GPU count {n}.\n'
|
||||||
|
f'Try batch={batch // n} or batch={batch // n + 1}')
|
||||||
space = ' ' * (len(s) + 1)
|
space = ' ' * (len(s) + 1)
|
||||||
for i, d in enumerate(devices):
|
for i, d in enumerate(devices):
|
||||||
p = torch.cuda.get_device_properties(i)
|
p = torch.cuda.get_device_properties(i)
|
||||||
|
@ -13,11 +13,11 @@ from ultralytics.yolo.utils.torch_utils import strip_optimizer
|
|||||||
|
|
||||||
class ClassificationTrainer(BaseTrainer):
|
class ClassificationTrainer(BaseTrainer):
|
||||||
|
|
||||||
def __init__(self, config=DEFAULT_CFG, overrides=None):
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
|
||||||
if overrides is None:
|
if overrides is None:
|
||||||
overrides = {}
|
overrides = {}
|
||||||
overrides["task"] = "classify"
|
overrides["task"] = "classify"
|
||||||
super().__init__(config, overrides)
|
super().__init__(cfg, overrides)
|
||||||
|
|
||||||
def set_model_attributes(self):
|
def set_model_attributes(self):
|
||||||
self.model.names = self.data["names"]
|
self.model.names = self.data["names"]
|
||||||
|
@ -47,7 +47,7 @@ class ClassificationValidator(BaseValidator):
|
|||||||
|
|
||||||
def val(cfg=DEFAULT_CFG):
|
def val(cfg=DEFAULT_CFG):
|
||||||
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
|
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
|
||||||
cfg.data = cfg.data or "imagenette160"
|
cfg.data = cfg.data or "mnist160"
|
||||||
validator = ClassificationValidator(args=cfg)
|
validator = ClassificationValidator(args=cfg)
|
||||||
validator(model=cfg.model)
|
validator(model=cfg.model)
|
||||||
|
|
||||||
|
@ -18,11 +18,11 @@ from ultralytics.yolo.v8.detect.train import Loss
|
|||||||
# BaseTrainer python usage
|
# BaseTrainer python usage
|
||||||
class SegmentationTrainer(v8.detect.DetectionTrainer):
|
class SegmentationTrainer(v8.detect.DetectionTrainer):
|
||||||
|
|
||||||
def __init__(self, config=DEFAULT_CFG, overrides=None):
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
|
||||||
if overrides is None:
|
if overrides is None:
|
||||||
overrides = {}
|
overrides = {}
|
||||||
overrides["task"] = "segment"
|
overrides["task"] = "segment"
|
||||||
super().__init__(config, overrides)
|
super().__init__(cfg, overrides)
|
||||||
|
|
||||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||||
model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose)
|
model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user