mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Add Classification model YAML support (#154)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
0e5a7ae623
commit
07eab49c3d
2
.github/workflows/ci.yaml
vendored
2
.github/workflows/ci.yaml
vendored
@ -100,5 +100,5 @@ jobs:
|
|||||||
- name: Test classification
|
- name: Test classification
|
||||||
shell: bash # for Windows compatibility
|
shell: bash # for Windows compatibility
|
||||||
run: |
|
run: |
|
||||||
yolo task=classify mode=train model=resnet18 data=mnist160 epochs=1 imgsz=32
|
yolo task=classify mode=train model=yolov8n-cls.yaml data=mnist160 epochs=1 imgsz=32
|
||||||
yolo task=classify mode=val model=runs/classify/train/weights/last.pt data=mnist160
|
yolo task=classify mode=val model=runs/classify/train/weights/last.pt data=mnist160
|
||||||
|
@ -3,8 +3,8 @@ from pathlib import Path
|
|||||||
|
|
||||||
from ultralytics.yolo.utils import ROOT, SETTINGS
|
from ultralytics.yolo.utils import ROOT, SETTINGS
|
||||||
|
|
||||||
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt'
|
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n'
|
||||||
CFG = 'yolov8n.yaml'
|
CFG = 'yolov8n'
|
||||||
|
|
||||||
|
|
||||||
def test_checks():
|
def test_checks():
|
||||||
@ -12,25 +12,25 @@ def test_checks():
|
|||||||
|
|
||||||
|
|
||||||
# Train checks ---------------------------------------------------------------------------------------------------------
|
# Train checks ---------------------------------------------------------------------------------------------------------
|
||||||
def test_train_detect():
|
def test_train_det():
|
||||||
os.system(f'yolo mode=train task=detect model={MODEL} data=coco128.yaml imgsz=32 epochs=1')
|
os.system(f'yolo mode=train task=detect model={CFG}.yaml data=coco128.yaml imgsz=32 epochs=1')
|
||||||
|
|
||||||
|
|
||||||
def test_train_segment():
|
def test_train_seg():
|
||||||
os.system('yolo mode=train task=segment model=yolov8n-seg.yaml data=coco128-seg.yaml imgsz=32 epochs=1')
|
os.system(f'yolo mode=train task=segment model={CFG}-seg.yaml data=coco128-seg.yaml imgsz=32 epochs=1')
|
||||||
|
|
||||||
|
|
||||||
def test_train_classify():
|
def test_train_cls():
|
||||||
pass
|
os.system(f'yolo mode=train task=classify model={CFG}-cls.yaml data=imagenette160 imgsz=32 epochs=1')
|
||||||
|
|
||||||
|
|
||||||
# Val checks -----------------------------------------------------------------------------------------------------------
|
# Val checks -----------------------------------------------------------------------------------------------------------
|
||||||
def test_val_detect():
|
def test_val_detect():
|
||||||
os.system(f'yolo mode=val task=detect model={MODEL} data=coco128.yaml imgsz=32 epochs=1')
|
os.system(f'yolo mode=val task=detect model={MODEL}.pt data=coco128.yaml imgsz=32 epochs=1')
|
||||||
|
|
||||||
|
|
||||||
def test_val_segment():
|
def test_val_segment():
|
||||||
pass
|
os.system(f'yolo mode=val task=segment model={MODEL}-seg.pt data=coco128-seg.yaml imgsz=32 epochs=1')
|
||||||
|
|
||||||
|
|
||||||
def test_val_classify():
|
def test_val_classify():
|
||||||
@ -39,11 +39,11 @@ def test_val_classify():
|
|||||||
|
|
||||||
# Predict checks -------------------------------------------------------------------------------------------------------
|
# Predict checks -------------------------------------------------------------------------------------------------------
|
||||||
def test_predict_detect():
|
def test_predict_detect():
|
||||||
os.system(f"yolo mode=predict model={MODEL} source={ROOT / 'assets'}")
|
os.system(f"yolo mode=predict model={MODEL}.pt source={ROOT / 'assets'}")
|
||||||
|
|
||||||
|
|
||||||
def test_predict_segment():
|
def test_predict_segment():
|
||||||
pass
|
os.system(f"yolo mode=predict model={MODEL}-seg.pt source={ROOT / 'assets'}")
|
||||||
|
|
||||||
|
|
||||||
def test_predict_classify():
|
def test_predict_classify():
|
||||||
@ -52,11 +52,11 @@ def test_predict_classify():
|
|||||||
|
|
||||||
# Export checks --------------------------------------------------------------------------------------------------------
|
# Export checks --------------------------------------------------------------------------------------------------------
|
||||||
def test_export_detect_torchscript():
|
def test_export_detect_torchscript():
|
||||||
os.system(f'yolo mode=export model={MODEL} format=torchscript')
|
os.system(f'yolo mode=export model={MODEL}.pt format=torchscript')
|
||||||
|
|
||||||
|
|
||||||
def test_export_segment_torchscript():
|
def test_export_segment_torchscript():
|
||||||
pass
|
os.system(f'yolo mode=export model={MODEL}-seg.pt format=torchscript')
|
||||||
|
|
||||||
|
|
||||||
def test_export_classify_torchscript():
|
def test_export_classify_torchscript():
|
||||||
|
@ -71,7 +71,7 @@ def test_segment():
|
|||||||
def test_classify():
|
def test_classify():
|
||||||
overrides = {
|
overrides = {
|
||||||
"data": "imagenette160",
|
"data": "imagenette160",
|
||||||
"model": "squeezenet1_0",
|
"model": "yolov8n-cls.yaml",
|
||||||
"imgsz": 32,
|
"imgsz": 32,
|
||||||
"epochs": 1,
|
"epochs": 1,
|
||||||
"batch": 64,
|
"batch": 64,
|
||||||
|
@ -3,8 +3,8 @@ from pathlib import Path
|
|||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
from ultralytics.yolo.utils import ROOT, SETTINGS
|
from ultralytics.yolo.utils import ROOT, SETTINGS
|
||||||
|
|
||||||
CFG = 'yolov8n.yaml'
|
|
||||||
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt'
|
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt'
|
||||||
|
CFG = 'yolov8n.yaml'
|
||||||
SOURCE = ROOT / 'assets/bus.jpg'
|
SOURCE = ROOT / 'assets/bus.jpg'
|
||||||
|
|
||||||
|
|
||||||
|
@ -662,12 +662,10 @@ class Segment(Detect):
|
|||||||
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
|
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
p = self.proto(x[0])
|
p = self.proto(x[0]) # mask protos
|
||||||
|
bs = p.shape[0] # batch size
|
||||||
|
|
||||||
mc = [] # mask coefficient
|
mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
|
||||||
for i in range(self.nl):
|
|
||||||
mc.append(self.cv4[i](x[i]))
|
|
||||||
mc = torch.cat([mi.view(p.shape[0], self.nm, -1) for mi in mc], 2)
|
|
||||||
x = self.detect(self, x)
|
x = self.detect(self, x)
|
||||||
if self.training:
|
if self.training:
|
||||||
return x, mc, p
|
return x, mc, p
|
||||||
|
@ -1,11 +1,9 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import thop
|
import thop
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torchvision
|
|
||||||
|
|
||||||
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
|
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
|
||||||
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
|
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
|
||||||
@ -226,9 +224,15 @@ class SegmentationModel(DetectionModel):
|
|||||||
|
|
||||||
class ClassificationModel(BaseModel):
|
class ClassificationModel(BaseModel):
|
||||||
# YOLOv5 classification model
|
# YOLOv5 classification model
|
||||||
def __init__(self, cfg=None, model=None, nc=1000, cutoff=10): # yaml, model, number of classes, cutoff index
|
def __init__(self,
|
||||||
|
cfg=None,
|
||||||
|
model=None,
|
||||||
|
ch=3,
|
||||||
|
nc=1000,
|
||||||
|
cutoff=10,
|
||||||
|
verbose=True): # yaml, model, number of classes, cutoff index
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg)
|
self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg, ch, nc, verbose)
|
||||||
|
|
||||||
def _from_detection_model(self, model, nc=1000, cutoff=10):
|
def _from_detection_model(self, model, nc=1000, cutoff=10):
|
||||||
# Create a YOLOv5 classification model from a YOLOv5 detection model
|
# Create a YOLOv5 classification model from a YOLOv5 detection model
|
||||||
@ -246,9 +250,15 @@ class ClassificationModel(BaseModel):
|
|||||||
self.save = []
|
self.save = []
|
||||||
self.nc = nc
|
self.nc = nc
|
||||||
|
|
||||||
def _from_yaml(self, cfg):
|
def _from_yaml(self, cfg, ch, nc, verbose):
|
||||||
# TODO: Create a YOLOv5 classification model from a *.yaml file
|
self.yaml = cfg if isinstance(cfg, dict) else yaml_load(check_yaml(cfg), append_filename=True) # cfg dict
|
||||||
self.model = None
|
# Define model
|
||||||
|
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
|
||||||
|
if nc and nc != self.yaml['nc']:
|
||||||
|
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
|
||||||
|
self.yaml['nc'] = nc # override yaml value
|
||||||
|
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch], verbose=verbose) # model, savelist
|
||||||
|
self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
|
||||||
|
|
||||||
def load(self, weights):
|
def load(self, weights):
|
||||||
model = weights["model"] if isinstance(weights, dict) else weights # torchvision models are not dicts
|
model = weights["model"] if isinstance(weights, dict) else weights # torchvision models are not dicts
|
||||||
@ -351,7 +361,7 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
|||||||
|
|
||||||
|
|
||||||
def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||||
# Parse a YOLOv5 model.yaml dictionary
|
# Parse a YOLO model.yaml dictionary
|
||||||
if verbose:
|
if verbose:
|
||||||
LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
|
LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
|
||||||
nc, gd, gw, act = d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
|
nc, gd, gw, act = d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
|
||||||
@ -359,7 +369,6 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||||||
Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
|
Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
|
||||||
if verbose:
|
if verbose:
|
||||||
LOGGER.info(f"{colorstr('activation:')} {act}") # print
|
LOGGER.info(f"{colorstr('activation:')} {act}") # print
|
||||||
no = nc + 4 # number of outputs = classes + box
|
|
||||||
|
|
||||||
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
|
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
|
||||||
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
|
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
|
||||||
@ -370,10 +379,10 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||||||
|
|
||||||
n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
|
n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
|
||||||
if m in {
|
if m in {
|
||||||
Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus, BottleneckCSP,
|
Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
|
||||||
C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
|
BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
|
||||||
c1, c2 = ch[f], args[0]
|
c1, c2 = ch[f], args[0]
|
||||||
if c2 != no: # if not output
|
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
|
||||||
c2 = make_divisible(c2 * gw, 8)
|
c2 = make_divisible(c2 * gw, 8)
|
||||||
|
|
||||||
args = [c1, c2, *args[1:]]
|
args = [c1, c2, *args[1:]]
|
||||||
@ -384,7 +393,6 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||||||
args = [ch[f]]
|
args = [ch[f]]
|
||||||
elif m is Concat:
|
elif m is Concat:
|
||||||
c2 = sum(ch[x] for x in f)
|
c2 = sum(ch[x] for x in f)
|
||||||
# TODO: channel, gw, gd
|
|
||||||
elif m in {Detect, Segment}:
|
elif m in {Detect, Segment}:
|
||||||
args.append([ch[x] for x in f])
|
args.append([ch[x] for x in f])
|
||||||
if m is Segment:
|
if m is Segment:
|
||||||
|
@ -255,12 +255,28 @@ def check_dataset_yaml(data, autodownload=True):
|
|||||||
|
|
||||||
|
|
||||||
def check_dataset(dataset: str):
|
def check_dataset(dataset: str):
|
||||||
data = Path.cwd() / "datasets" / dataset
|
"""
|
||||||
data_dir = data if data.is_dir() else (Path.cwd() / data)
|
Check a classification dataset such as Imagenet.
|
||||||
|
|
||||||
|
Copy code
|
||||||
|
This function takes a `dataset` name as input and returns a dictionary containing information about the dataset.
|
||||||
|
If the dataset is not found, it attempts to download the dataset from the internet and save it to the local file system.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset (str): Name of the dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
data (dict): A dictionary containing the following keys and values:
|
||||||
|
'train': Path object for the directory containing the training set of the dataset
|
||||||
|
'val': Path object for the directory containing the validation set of the dataset
|
||||||
|
'nc': Number of classes in the dataset
|
||||||
|
'names': List of class names in the dataset
|
||||||
|
"""
|
||||||
|
data_dir = (Path.cwd() / "datasets" / dataset).resolve()
|
||||||
if not data_dir.is_dir():
|
if not data_dir.is_dir():
|
||||||
LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
|
LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
|
||||||
t = time.time()
|
t = time.time()
|
||||||
if str(data) == 'imagenet':
|
if dataset == 'imagenet':
|
||||||
subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
|
subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
|
||||||
else:
|
else:
|
||||||
url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
|
url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
|
||||||
@ -271,5 +287,4 @@ def check_dataset(dataset: str):
|
|||||||
test_set = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val' # data/test or data/val
|
test_set = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val' # data/test or data/val
|
||||||
nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
|
nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
|
||||||
names = [name for name in os.listdir(data_dir / 'train') if os.path.isdir(data_dir / 'train' / name)]
|
names = [name for name in os.listdir(data_dir / 'train') if os.path.isdir(data_dir / 'train' / name)]
|
||||||
data = {"train": train_set, "val": test_set, "nc": nc, "names": names}
|
return {"train": train_set, "val": test_set, "nc": nc, "names": names}
|
||||||
return data
|
|
||||||
|
@ -103,13 +103,9 @@ class YOLO:
|
|||||||
Args:
|
Args:
|
||||||
verbose (bool): Controls verbosity.
|
verbose (bool): Controls verbosity.
|
||||||
"""
|
"""
|
||||||
if not self.model:
|
|
||||||
LOGGER.info("model not initialized!")
|
|
||||||
self.model.info(verbose=verbose)
|
self.model.info(verbose=verbose)
|
||||||
|
|
||||||
def fuse(self):
|
def fuse(self):
|
||||||
if not self.model:
|
|
||||||
LOGGER.info("model not initialized!")
|
|
||||||
self.model.fuse()
|
self.model.fuse()
|
||||||
|
|
||||||
@smart_inference_mode()
|
@smart_inference_mode()
|
||||||
@ -139,9 +135,6 @@ class YOLO:
|
|||||||
data (str): The dataset to validate on. Accepts all formats accepted by yolo
|
data (str): The dataset to validate on. Accepts all formats accepted by yolo
|
||||||
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
|
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
|
||||||
"""
|
"""
|
||||||
if not self.model:
|
|
||||||
raise ModuleNotFoundError("model not initialized!")
|
|
||||||
|
|
||||||
overrides = self.overrides.copy()
|
overrides = self.overrides.copy()
|
||||||
overrides.update(kwargs)
|
overrides.update(kwargs)
|
||||||
overrides["mode"] = "val"
|
overrides["mode"] = "val"
|
||||||
@ -177,8 +170,6 @@ class YOLO:
|
|||||||
**kwargs (Any): Any number of arguments representing the training configuration. List of all args can be found in 'config' section.
|
**kwargs (Any): Any number of arguments representing the training configuration. List of all args can be found in 'config' section.
|
||||||
You can pass all arguments as a yaml file in `cfg`. Other args are ignored if `cfg` file is passed
|
You can pass all arguments as a yaml file in `cfg`. Other args are ignored if `cfg` file is passed
|
||||||
"""
|
"""
|
||||||
if not self.model:
|
|
||||||
raise AttributeError("model not initialized. Use .new() or .load()")
|
|
||||||
overrides = self.overrides.copy()
|
overrides = self.overrides.copy()
|
||||||
overrides.update(kwargs)
|
overrides.update(kwargs)
|
||||||
if kwargs.get("cfg"):
|
if kwargs.get("cfg"):
|
||||||
@ -193,10 +184,8 @@ class YOLO:
|
|||||||
|
|
||||||
self.trainer = self.TrainerClass(overrides=overrides)
|
self.trainer = self.TrainerClass(overrides=overrides)
|
||||||
if not overrides.get("resume"): # manually set model only if not resuming
|
if not overrides.get("resume"): # manually set model only if not resuming
|
||||||
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None,
|
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
|
||||||
cfg=self.model.yaml if self.task != "classify" else None)
|
|
||||||
self.model = self.trainer.model
|
self.model = self.trainer.model
|
||||||
|
|
||||||
self.trainer.train()
|
self.trainer.train()
|
||||||
|
|
||||||
def to(self, device):
|
def to(self, device):
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
@ -13,7 +11,9 @@ from ultralytics.yolo.utils import DEFAULT_CONFIG
|
|||||||
|
|
||||||
class ClassificationTrainer(BaseTrainer):
|
class ClassificationTrainer(BaseTrainer):
|
||||||
|
|
||||||
def __init__(self, config=DEFAULT_CONFIG, overrides={}):
|
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
|
||||||
|
if overrides is None:
|
||||||
|
overrides = {}
|
||||||
overrides["task"] = "classify"
|
overrides["task"] = "classify"
|
||||||
super().__init__(config, overrides)
|
super().__init__(config, overrides)
|
||||||
|
|
||||||
@ -25,6 +25,10 @@ class ClassificationTrainer(BaseTrainer):
|
|||||||
if weights:
|
if weights:
|
||||||
model.load(weights)
|
model.load(weights)
|
||||||
|
|
||||||
|
# Update defaults
|
||||||
|
if self.args.imgsz == 640:
|
||||||
|
self.args.imgsz = 224
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def setup_model(self):
|
def setup_model(self):
|
||||||
@ -36,22 +40,17 @@ class ClassificationTrainer(BaseTrainer):
|
|||||||
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
||||||
return
|
return
|
||||||
|
|
||||||
model = self.model
|
model = str(self.model)
|
||||||
pretrained = False
|
|
||||||
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
|
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
|
||||||
if model.endswith(".pt"):
|
if model.endswith(".pt"):
|
||||||
model = model.split(".")[0]
|
self.model = attempt_load_weights(model, device='cpu')
|
||||||
pretrained = True
|
elif model.endswith(".yaml"):
|
||||||
else:
|
|
||||||
self.model = self.get_model(cfg=model)
|
self.model = self.get_model(cfg=model)
|
||||||
|
|
||||||
# order: check local file -> torchvision assets -> ultralytics asset
|
|
||||||
if Path(f"{model}.pt").is_file(): # local file
|
|
||||||
self.model = attempt_load_weights(f"{model}.pt", device='cpu')
|
|
||||||
elif model in torchvision.models.__dict__:
|
elif model in torchvision.models.__dict__:
|
||||||
|
pretrained = True
|
||||||
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
|
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
|
||||||
else:
|
else:
|
||||||
self.model = attempt_load_weights(f"{model}.pt", device='cpu')
|
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
|
||||||
|
|
||||||
return # dont return ckpt. Classification doesn't support resume
|
return # dont return ckpt. Classification doesn't support resume
|
||||||
|
|
||||||
@ -66,6 +65,10 @@ class ClassificationTrainer(BaseTrainer):
|
|||||||
batch["cls"] = batch["cls"].to(self.device)
|
batch["cls"] = batch["cls"].to(self.device)
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
def progress_string(self):
|
||||||
|
return ('\n' + '%11s' *
|
||||||
|
(4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
|
||||||
|
|
||||||
def get_validator(self):
|
def get_validator(self):
|
||||||
return v8.classify.ClassificationValidator(self.test_loader, self.save_dir, logger=self.console)
|
return v8.classify.ClassificationValidator(self.test_loader, self.save_dir, logger=self.console)
|
||||||
|
|
||||||
@ -73,9 +76,6 @@ class ClassificationTrainer(BaseTrainer):
|
|||||||
loss = torch.nn.functional.cross_entropy(preds, batch["cls"])
|
loss = torch.nn.functional.cross_entropy(preds, batch["cls"])
|
||||||
return loss, loss
|
return loss, loss
|
||||||
|
|
||||||
def check_resume(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def resume_training(self, ckpt):
|
def resume_training(self, ckpt):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -85,10 +85,13 @@ class ClassificationTrainer(BaseTrainer):
|
|||||||
|
|
||||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||||
def train(cfg):
|
def train(cfg):
|
||||||
cfg.model = cfg.model or "resnet18"
|
cfg.model = cfg.model or "yolov8n-cls.yaml" # or "resnet18"
|
||||||
cfg.data = cfg.data or "imagenette160" # or yolo.ClassificationDataset("mnist")
|
cfg.data = cfg.data or "imagenette160" # or yolo.ClassificationDataset("mnist")
|
||||||
trainer = ClassificationTrainer(cfg)
|
# trainer = ClassificationTrainer(cfg)
|
||||||
trainer.train()
|
# trainer.train()
|
||||||
|
from ultralytics import YOLO
|
||||||
|
model = YOLO(cfg.model)
|
||||||
|
model.train(**cfg)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
23
ultralytics/yolo/v8/models/cls/yolov8l-cls.yaml
Normal file
23
ultralytics/yolo/v8/models/cls/yolov8l-cls.yaml
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
||||||
|
|
||||||
|
# Parameters
|
||||||
|
nc: 1000 # number of classes
|
||||||
|
depth_multiple: 1.00 # scales module repeats
|
||||||
|
width_multiple: 1.00 # scales convolution channels
|
||||||
|
|
||||||
|
# YOLOv8.0n backbone
|
||||||
|
backbone:
|
||||||
|
# [from, repeats, module, args]
|
||||||
|
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
||||||
|
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
||||||
|
- [-1, 3, C2f, [128, True]]
|
||||||
|
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
||||||
|
- [-1, 6, C2f, [256, True]]
|
||||||
|
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
||||||
|
- [-1, 6, C2f, [512, True]]
|
||||||
|
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
||||||
|
- [-1, 3, C2f, [1024, True]]
|
||||||
|
|
||||||
|
# YOLOv8.0n head
|
||||||
|
head:
|
||||||
|
- [-1, 1, Classify, [nc]]
|
23
ultralytics/yolo/v8/models/cls/yolov8m-cls.yaml
Normal file
23
ultralytics/yolo/v8/models/cls/yolov8m-cls.yaml
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
||||||
|
|
||||||
|
# Parameters
|
||||||
|
nc: 1000 # number of classes
|
||||||
|
depth_multiple: 0.67 # scales module repeats
|
||||||
|
width_multiple: 0.75 # scales convolution channels
|
||||||
|
|
||||||
|
# YOLOv8.0n backbone
|
||||||
|
backbone:
|
||||||
|
# [from, repeats, module, args]
|
||||||
|
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
||||||
|
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
||||||
|
- [-1, 3, C2f, [128, True]]
|
||||||
|
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
||||||
|
- [-1, 6, C2f, [256, True]]
|
||||||
|
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
||||||
|
- [-1, 6, C2f, [512, True]]
|
||||||
|
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
||||||
|
- [-1, 3, C2f, [1024, True]]
|
||||||
|
|
||||||
|
# YOLOv8.0n head
|
||||||
|
head:
|
||||||
|
- [-1, 1, Classify, [nc]]
|
23
ultralytics/yolo/v8/models/cls/yolov8n-cls.yaml
Normal file
23
ultralytics/yolo/v8/models/cls/yolov8n-cls.yaml
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
||||||
|
|
||||||
|
# Parameters
|
||||||
|
nc: 1000 # number of classes
|
||||||
|
depth_multiple: 0.33 # scales module repeats
|
||||||
|
width_multiple: 0.25 # scales convolution channels
|
||||||
|
|
||||||
|
# YOLOv8.0n backbone
|
||||||
|
backbone:
|
||||||
|
# [from, repeats, module, args]
|
||||||
|
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
||||||
|
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
||||||
|
- [-1, 3, C2f, [128, True]]
|
||||||
|
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
||||||
|
- [-1, 6, C2f, [256, True]]
|
||||||
|
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
||||||
|
- [-1, 6, C2f, [512, True]]
|
||||||
|
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
||||||
|
- [-1, 3, C2f, [1024, True]]
|
||||||
|
|
||||||
|
# YOLOv8.0n head
|
||||||
|
head:
|
||||||
|
- [-1, 1, Classify, [nc]]
|
23
ultralytics/yolo/v8/models/cls/yolov8s-cls.yaml
Normal file
23
ultralytics/yolo/v8/models/cls/yolov8s-cls.yaml
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
||||||
|
|
||||||
|
# Parameters
|
||||||
|
nc: 1000 # number of classes
|
||||||
|
depth_multiple: 0.33 # scales module repeats
|
||||||
|
width_multiple: 0.50 # scales convolution channels
|
||||||
|
|
||||||
|
# YOLOv8.0n backbone
|
||||||
|
backbone:
|
||||||
|
# [from, repeats, module, args]
|
||||||
|
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
||||||
|
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
||||||
|
- [-1, 3, C2f, [128, True]]
|
||||||
|
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
||||||
|
- [-1, 6, C2f, [256, True]]
|
||||||
|
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
||||||
|
- [-1, 6, C2f, [512, True]]
|
||||||
|
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
||||||
|
- [-1, 3, C2f, [1024, True]]
|
||||||
|
|
||||||
|
# YOLOv8.0n head
|
||||||
|
head:
|
||||||
|
- [-1, 1, Classify, [nc]]
|
23
ultralytics/yolo/v8/models/cls/yolov8x-cls.yaml
Normal file
23
ultralytics/yolo/v8/models/cls/yolov8x-cls.yaml
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
||||||
|
|
||||||
|
# Parameters
|
||||||
|
nc: 1000 # number of classes
|
||||||
|
depth_multiple: 1.00 # scales module repeats
|
||||||
|
width_multiple: 1.25 # scales convolution channels
|
||||||
|
|
||||||
|
# YOLOv8.0n backbone
|
||||||
|
backbone:
|
||||||
|
# [from, repeats, module, args]
|
||||||
|
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
|
||||||
|
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
|
||||||
|
- [-1, 3, C2f, [128, True]]
|
||||||
|
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
|
||||||
|
- [-1, 6, C2f, [256, True]]
|
||||||
|
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
|
||||||
|
- [-1, 6, C2f, [512, True]]
|
||||||
|
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
|
||||||
|
- [-1, 3, C2f, [1024, True]]
|
||||||
|
|
||||||
|
# YOLOv8.0n head
|
||||||
|
head:
|
||||||
|
- [-1, 1, Classify, [nc]]
|
Loading…
x
Reference in New Issue
Block a user