mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
ultralytics 8.0.19
seg/det dataset warning and DDP-cls/seg fixes (#595)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: 曾逸夫(Zeng Yifu) <41098760+Zengyf-CVer@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
parent
936414c615
commit
520825c4b2
@ -14,7 +14,7 @@ ci:
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.3.0
|
||||
rev: v4.4.0
|
||||
hooks:
|
||||
# - id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
@ -25,14 +25,14 @@ repos:
|
||||
- id: check-docstring-first
|
||||
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v2.37.3
|
||||
rev: v3.3.1
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
name: Upgrade code
|
||||
args: [ --py37-plus ]
|
||||
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
rev: 5.10.1
|
||||
rev: 5.11.4
|
||||
hooks:
|
||||
- id: isort
|
||||
name: Sort imports
|
||||
@ -59,6 +59,13 @@ repos:
|
||||
- id: flake8
|
||||
name: PEP8
|
||||
|
||||
- repo: https://github.com/codespell-project/codespell
|
||||
rev: v2.2.2
|
||||
hooks:
|
||||
- id: codespell
|
||||
args:
|
||||
- --ignore-words-list=crate,nd
|
||||
|
||||
#- repo: https://github.com/asottile/yesqa
|
||||
# rev: v1.4.0
|
||||
# hooks:
|
||||
|
@ -183,7 +183,7 @@ Default arguments can be overriden by simply passing them as arguments in the CL
|
||||
You can override the `default.yaml` config file entirely by passing a new file with the `cfg` arguments,
|
||||
i.e. `cfg=custom.yaml`.
|
||||
|
||||
To do this first create a copy of `default.yaml` in your current working dir with the `yolo copy-config` command.
|
||||
To do this first create a copy of `default.yaml` in your current working dir with the `yolo copy-cfg` command.
|
||||
|
||||
This will create `default_copy.yaml`, which you can then pass as `cfg=default_copy.yaml` along with any additional args,
|
||||
like `imgsz=320` in this example:
|
||||
@ -192,6 +192,6 @@ like `imgsz=320` in this example:
|
||||
|
||||
=== "CLI"
|
||||
```bash
|
||||
yolo copy-config
|
||||
yolo copy-cfg
|
||||
yolo cfg=default_copy.yaml imgsz=320
|
||||
```
|
@ -638,11 +638,11 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"# Load YOLOv8n-cls, train it on imagenette160 for 3 epochs and predict an image with it\n",
|
||||
"# Load YOLOv8n-cls, train it on mnist160 for 3 epochs and predict an image with it\n",
|
||||
"from ultralytics import YOLO\n",
|
||||
"\n",
|
||||
"model = YOLO('yolov8n-cls.pt') # load a pretrained YOLOv8n classification model\n",
|
||||
"model.train(data='imagenette160', epochs=3) # train the model\n",
|
||||
"model.train(data='mnist160', epochs=3) # train the model\n",
|
||||
"model('https://ultralytics.com/images/bus.jpg') # predict on an image"
|
||||
],
|
||||
"metadata": {
|
||||
|
@ -3,13 +3,13 @@
|
||||
from pathlib import Path
|
||||
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, ROOT, SETTINGS
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, SETTINGS
|
||||
from ultralytics.yolo.v8 import classify, detect, segment
|
||||
|
||||
CFG_DET = 'yolov8n.yaml'
|
||||
CFG_SEG = 'yolov8n-seg.yaml'
|
||||
CFG_CLS = 'squeezenet1_0'
|
||||
CFG = get_cfg(DEFAULT_CFG_PATH)
|
||||
CFG = get_cfg(DEFAULT_CFG)
|
||||
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n'
|
||||
SOURCE = ROOT / "assets"
|
||||
|
||||
|
@ -313,13 +313,39 @@ class ClassificationModel(BaseModel):
|
||||
# Functions ------------------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def torch_safe_load(weight):
|
||||
"""
|
||||
This function attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it
|
||||
catches the error, logs a warning message, and attempts to install the missing module via the check_requirements()
|
||||
function. After installation, the function again attempts to load the model using torch.load().
|
||||
|
||||
Args:
|
||||
weight (str): The file path of the PyTorch model.
|
||||
|
||||
Returns:
|
||||
The loaded PyTorch model.
|
||||
"""
|
||||
from ultralytics.yolo.utils.downloads import attempt_download
|
||||
|
||||
file = attempt_download(weight) # search online if missing locally
|
||||
try:
|
||||
return torch.load(file, map_location='cpu') # load
|
||||
except ModuleNotFoundError as e:
|
||||
if e.name == 'omegaconf': # e.name is missing module name
|
||||
LOGGER.warning(f"WARNING ⚠️ {weight} requires {e.name}, which is not in ultralytics requirements."
|
||||
f"\nAutoInstall will run now for {e.name} but this feature will be removed in the future."
|
||||
f"\nRecommend fixes are to train a new model using updated ultraltyics package or to "
|
||||
f"download updated models from https://github.com/ultralytics/assets/releases/tag/v0.0.0")
|
||||
check_requirements(e.name) # install missing module
|
||||
return torch.load(file, map_location='cpu') # load
|
||||
|
||||
|
||||
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
||||
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
||||
from ultralytics.yolo.utils.downloads import attempt_download
|
||||
|
||||
model = Ensemble()
|
||||
for w in weights if isinstance(weights, list) else [weights]:
|
||||
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
|
||||
ckpt = torch_safe_load(w) # load ckpt
|
||||
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
||||
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||
|
||||
@ -355,18 +381,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
||||
|
||||
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
||||
# Loads a single model weights
|
||||
from ultralytics.yolo.utils.downloads import attempt_download
|
||||
|
||||
weight = attempt_download(weight)
|
||||
try:
|
||||
ckpt = torch.load(weight, map_location='cpu') # load
|
||||
except ModuleNotFoundError:
|
||||
LOGGER.warning(f"WARNING ⚠️ {weight} is deprecated as it requires omegaconf, which is now removed from "
|
||||
"ultralytics requirements.\nAutoInstall will occur now but this feature will be removed for "
|
||||
"omegaconf models in the future.\nPlease train a new model or download updated models "
|
||||
"from https://github.com/ultralytics/assets/releases/tag/v0.0.0")
|
||||
check_requirements('omegaconf')
|
||||
ckpt = torch.load(weight, map_location='cpu') # load
|
||||
ckpt = torch_safe_load(weight) # load ckpt
|
||||
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
||||
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||
|
||||
|
@ -611,6 +611,8 @@ class LoadImagesAndLabels(Dataset):
|
||||
|
||||
def cache_labels(self, path=Path('./labels.cache'), prefix=''):
|
||||
# Cache dataset labels, check images and read shapes
|
||||
if path.exists():
|
||||
path.unlink() # remove *.cache file if exists
|
||||
x = {} # dict
|
||||
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
||||
desc = f"{prefix}Scanning {path.parent / path.stem}..."
|
||||
|
@ -47,6 +47,8 @@ class YOLODataset(BaseDataset):
|
||||
|
||||
def cache_labels(self, path=Path("./labels.cache")):
|
||||
# Cache dataset labels, check images and read shapes
|
||||
if path.exists():
|
||||
path.unlink() # remove *.cache file if exists
|
||||
x = {"labels": []}
|
||||
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
|
||||
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
|
||||
@ -85,7 +87,7 @@ class YOLODataset(BaseDataset):
|
||||
x["results"] = nf, nm, ne, nc, len(self.im_files)
|
||||
x["msgs"] = msgs # warnings
|
||||
x["version"] = self.cache_version # cache version
|
||||
self.im_files = [lb["im_file"] for lb in x["labels"]]
|
||||
self.im_files = [lb["im_file"] for lb in x["labels"]] # update im_files
|
||||
if is_dir_writeable(path.parent):
|
||||
np.save(str(path), x) # save cache for next time
|
||||
path.with_suffix(".cache.npy").rename(path) # remove .npy suffix
|
||||
@ -116,6 +118,17 @@ class YOLODataset(BaseDataset):
|
||||
# Read cache
|
||||
[cache.pop(k) for k in ("hash", "version", "msgs")] # remove items
|
||||
labels = cache["labels"]
|
||||
|
||||
# Check if the dataset is all boxes or all segments
|
||||
len_boxes = sum(len(lb["bboxes"]) for lb in labels)
|
||||
len_segments = sum(len(lb["segments"]) for lb in labels)
|
||||
if len_segments and len_boxes != len_segments:
|
||||
LOGGER.warning(
|
||||
f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
|
||||
f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
|
||||
"To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset.")
|
||||
for lb in labels:
|
||||
lb["segments"] = []
|
||||
nl = len(np.concatenate([label["cls"] for label in labels], 0)) # number of labels
|
||||
assert nl > 0, f"{self.prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}"
|
||||
return labels
|
||||
|
@ -14,7 +14,7 @@ import numpy as np
|
||||
import torch
|
||||
from PIL import ExifTags, Image, ImageOps
|
||||
|
||||
from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, ROOT, colorstr, yaml_load
|
||||
from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, ROOT, colorstr, emojis, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii
|
||||
from ultralytics.yolo.utils.downloads import download
|
||||
from ultralytics.yolo.utils.files import unzip_file
|
||||
@ -202,7 +202,10 @@ def check_det_dataset(dataset, autodownload=True):
|
||||
|
||||
# Checks
|
||||
for k in 'train', 'val', 'names':
|
||||
assert k in data, f"data.yaml '{k}:' field missing ❌"
|
||||
if k not in data:
|
||||
raise SyntaxError(
|
||||
emojis(f"{dataset} '{k}:' key missing ❌.\n"
|
||||
f"'train', 'val' and 'names' are required in data.yaml files."))
|
||||
if isinstance(data['names'], (list, tuple)): # old array format
|
||||
data['names'] = dict(enumerate(data['names'])) # convert to dict
|
||||
data['nc'] = len(data['names'])
|
||||
|
@ -388,7 +388,7 @@ class Exporter:
|
||||
@try_export
|
||||
def _export_engine(self, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
|
||||
# YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt
|
||||
assert self.im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `device==0`'
|
||||
assert self.im.device.type != 'cpu', "export running on CPU but must be on GPU, i.e. use 'device=0'"
|
||||
try:
|
||||
import tensorrt as trt # noqa
|
||||
except ImportError:
|
||||
|
@ -53,7 +53,12 @@ class YOLO:
|
||||
self.overrides = {} # overrides for trainer object
|
||||
|
||||
# Load or create new YOLO model
|
||||
{'.pt': self._load, '.yaml': self._new}[Path(model).suffix](model)
|
||||
load_methods = {'.pt': self._load, '.yaml': self._new}
|
||||
suffix = Path(model).suffix
|
||||
if suffix in load_methods:
|
||||
{'.pt': self._load, '.yaml': self._new}[suffix](model)
|
||||
else:
|
||||
raise NotImplementedError(f"'{suffix}' model loading not implemented")
|
||||
|
||||
def __call__(self, source=None, stream=False, verbose=False, **kwargs):
|
||||
return self.predict(source, stream, verbose, **kwargs)
|
||||
|
@ -35,7 +35,7 @@ from ultralytics.nn.autobackend import AutoBackend
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams
|
||||
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, SETTINGS, callbacks, colorstr, ops
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, SETTINGS, callbacks, colorstr, ops
|
||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
|
||||
from ultralytics.yolo.utils.files import increment_path
|
||||
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
|
||||
@ -61,12 +61,12 @@ class BasePredictor:
|
||||
data_path (str): Path to data.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG_PATH, overrides=None):
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
|
||||
"""
|
||||
Initializes the BasePredictor class.
|
||||
|
||||
Args:
|
||||
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
||||
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
|
||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||
"""
|
||||
self.args = get_cfg(cfg, overrides)
|
||||
|
@ -24,8 +24,8 @@ from ultralytics import __version__
|
||||
from ultralytics.nn.tasks import attempt_load_one_weight
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr,
|
||||
emojis, yaml_save)
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis,
|
||||
yaml_save)
|
||||
from ultralytics.yolo.utils.autobatch import check_train_batch_size
|
||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args
|
||||
from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
|
||||
@ -71,12 +71,12 @@ class BaseTrainer:
|
||||
csv (Path): Path to results CSV file.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG_PATH, overrides=None):
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
|
||||
"""
|
||||
Initializes the BaseTrainer class.
|
||||
|
||||
Args:
|
||||
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
||||
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
|
||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||
"""
|
||||
self.args = get_cfg(cfg, overrides)
|
||||
|
@ -10,7 +10,7 @@ from tqdm import tqdm
|
||||
from ultralytics.nn.autobackend import AutoBackend
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, emojis
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, emojis
|
||||
from ultralytics.yolo.utils.checks import check_imgsz
|
||||
from ultralytics.yolo.utils.files import increment_path
|
||||
from ultralytics.yolo.utils.ops import Profile
|
||||
@ -52,7 +52,7 @@ class BaseValidator:
|
||||
self.dataloader = dataloader
|
||||
self.pbar = pbar
|
||||
self.logger = logger or LOGGER
|
||||
self.args = args or get_cfg(DEFAULT_CFG_PATH)
|
||||
self.args = args or get_cfg(DEFAULT_CFG)
|
||||
self.model = None
|
||||
self.data = None
|
||||
self.device = None
|
||||
|
@ -127,8 +127,7 @@ def is_colab():
|
||||
Returns:
|
||||
bool: True if running inside a Colab notebook, False otherwise.
|
||||
"""
|
||||
# Check if the 'google.colab' module is present in sys.modules
|
||||
return 'google.colab' in sys.modules
|
||||
return 'COLAB_RELEASE_TAG' in os.environ or 'COLAB_BACKEND_VERSION' in os.environ
|
||||
|
||||
|
||||
def is_kaggle():
|
||||
|
@ -224,7 +224,7 @@ def check_file(file, suffix=''):
|
||||
for d in 'models', 'yolo/data': # search directories
|
||||
files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
|
||||
if not files:
|
||||
raise FileNotFoundError(f"{file} does not exist")
|
||||
raise FileNotFoundError(f"'{file}' does not exist")
|
||||
elif len(files) > 1:
|
||||
raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
|
||||
return files[0] # return file
|
||||
|
@ -10,17 +10,14 @@ from . import USER_CONFIG_DIR
|
||||
|
||||
|
||||
def find_free_network_port() -> int:
|
||||
# https://github.com/Lightning-AI/lightning/blob/master/src/lightning_lite/plugins/environments/lightning.py
|
||||
"""Finds a free port on localhost.
|
||||
|
||||
It is useful in single-node training when we don't want to connect to a real main node but have to set the
|
||||
`MASTER_PORT` environment variable.
|
||||
"""
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
s.bind(("", 0))
|
||||
port = s.getsockname()[1]
|
||||
s.close()
|
||||
return port
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(('127.0.0.1', 0))
|
||||
return s.getsockname()[1] # port
|
||||
|
||||
|
||||
def generate_ddp_file(trainer):
|
||||
|
@ -91,12 +91,10 @@ def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'):
|
||||
|
||||
file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
|
||||
if name in assets:
|
||||
url3 = 'https://drive.google.com/drive/folders/1EFQTEUeXWSFww0luse2jB9M1QNZQGwNl' # backup gdrive mirror
|
||||
safe_download(
|
||||
file,
|
||||
safe_download(file,
|
||||
url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
|
||||
min_bytes=1E5,
|
||||
error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag} or {url3}')
|
||||
error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag}')
|
||||
|
||||
return str(file)
|
||||
|
||||
|
@ -58,7 +58,7 @@ def DDP_model(model):
|
||||
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
|
||||
|
||||
|
||||
def select_device(device='', batch_size=0, newline=False):
|
||||
def select_device(device='', batch=0, newline=False):
|
||||
# device = None or 'cpu' or 0 or '0' or '0,1,2,3'
|
||||
ver = git_describe() or ultralytics.__version__ # git commit or pip package version
|
||||
s = f'Ultralytics YOLOv{ver} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
|
||||
@ -71,14 +71,15 @@ def select_device(device='', batch_size=0, newline=False):
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
|
||||
elif device: # non-cpu device requested
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
|
||||
assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
|
||||
f"Invalid CUDA 'device={device}' requested, use 'device=cpu' or pass valid CUDA device(s)"
|
||||
if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', ''))):
|
||||
raise ValueError(f"Invalid CUDA 'device={device}' requested, use 'device=cpu' or pass valid CUDA device(s)")
|
||||
|
||||
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
|
||||
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
|
||||
n = len(devices) # device count
|
||||
if n > 1 and batch_size > 0: # check batch_size is divisible by device_count
|
||||
assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
|
||||
if n > 1 and batch > 0 and batch % n != 0: # check batch_size is divisible by device_count
|
||||
raise ValueError(f'batch={batch} is not multiple of GPU count {n}.\n'
|
||||
f'Try batch={batch // n} or batch={batch // n + 1}')
|
||||
space = ' ' * (len(s) + 1)
|
||||
for i, d in enumerate(devices):
|
||||
p = torch.cuda.get_device_properties(i)
|
||||
|
@ -13,11 +13,11 @@ from ultralytics.yolo.utils.torch_utils import strip_optimizer
|
||||
|
||||
class ClassificationTrainer(BaseTrainer):
|
||||
|
||||
def __init__(self, config=DEFAULT_CFG, overrides=None):
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides["task"] = "classify"
|
||||
super().__init__(config, overrides)
|
||||
super().__init__(cfg, overrides)
|
||||
|
||||
def set_model_attributes(self):
|
||||
self.model.names = self.data["names"]
|
||||
|
@ -47,7 +47,7 @@ class ClassificationValidator(BaseValidator):
|
||||
|
||||
def val(cfg=DEFAULT_CFG):
|
||||
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
|
||||
cfg.data = cfg.data or "imagenette160"
|
||||
cfg.data = cfg.data or "mnist160"
|
||||
validator = ClassificationValidator(args=cfg)
|
||||
validator(model=cfg.model)
|
||||
|
||||
|
@ -18,11 +18,11 @@ from ultralytics.yolo.v8.detect.train import Loss
|
||||
# BaseTrainer python usage
|
||||
class SegmentationTrainer(v8.detect.DetectionTrainer):
|
||||
|
||||
def __init__(self, config=DEFAULT_CFG, overrides=None):
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides["task"] = "segment"
|
||||
super().__init__(config, overrides)
|
||||
super().__init__(cfg, overrides)
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose)
|
||||
|
Loading…
x
Reference in New Issue
Block a user