mirror of https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00

Add initial model interface (#30)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

parent 7b560f7861
commit 1054819a59
ultralytics/tests/test_model.py (new file, 13 lines)

@@ -0,0 +1,13 @@
+from ultralytics.yolo import YOLO
+
+
+def test_model():
+    model = YOLO()
+    model.new("assets/dummy_model.yaml")
+    model.model = "squeezenet1_0"  # temp solution before get_model is implemented
+    # model.load("yolov5n.pt")
+    model.train(data="imagenette160", epochs=1, lr0=0.01)
+
+
+if __name__ == "__main__":
+    test_model()
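
Note: the new test doubles as a first usage example for the interface added below. Assuming pytest is installed and assets/dummy_model.yaml exists, it runs as "pytest ultralytics/tests/test_model.py", or directly as "python ultralytics/tests/test_model.py" via the __main__ guard.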
@@ -1,4 +1,7 @@
 import ultralytics.yolo.v8 as v8
+
+from .engine.model import YOLO
 from .engine.trainer import BaseTrainer
 from .engine.validator import BaseValidator
-__all__ = ["BaseTrainer", "BaseValidator"]  # allow simpler import
+
+__all__ = ["BaseTrainer", "BaseValidator", "YOLO"]  # allow simpler import
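
Note: exporting YOLO from the package root is what enables the one-line import the new test relies on. A minimal sketch of the import this change allows:

from ultralytics.yolo import YOLO

model = YOLO()  # version=8 by default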
@@ -728,7 +728,7 @@ def classify_albumentations(
         if vflip > 0:
             T += [A.VerticalFlip(p=vflip)]
         if jitter > 0:
-            color_jitter = (float(jitter),) * 3  # repeat value for brightness, contrast, satuaration, 0 hue
+            color_jitter = (float(jitter),) * 3  # repeat value for brightness, contrast, saturation, 0 hue
             T += [A.ColorJitter(*color_jitter, 0)]
     else:  # Use fixed crop for eval set (reproducibility)
         T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
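
Note on the fixed line: the tuple expansion passes the same jitter value for brightness, contrast, and saturation, with hue pinned to 0. A minimal standalone sketch, assuming albumentations is installed:

import albumentations as A

jitter = 0.4  # example value
color_jitter = (float(jitter),) * 3  # (0.4, 0.4, 0.4)
t = A.ColorJitter(*color_jitter, 0)  # brightness, contrast, saturation jitter; hue fixed at 0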
@@ -51,7 +51,8 @@ def exif_size(img):
 def verify_image_label(args):
     # Verify one image-label pair
     im_file, lb_file, prefix, keypoint = args
-    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", None, None  # number (missing, found, empty, corrupt), message, segments, keypoints
+    # number (missing, found, empty, corrupt), message, segments, keypoints
+    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", None, None
     try:
         # verify images
         im = Image.open(im_file)
@@ -86,10 +87,10 @@ def verify_image_label(args):
                 kpts = np.zeros((lb.shape[0], 39))
                 for i in range(len(lb)):
                     kpt = np.delete(lb[i, 5:], np.arange(2, lb.shape[1] - 5,
-                                                         3))  # remove the occlusion paramater from the GT
+                                                         3))  # remove the occlusion parameter from the GT
                     kpts[i] = np.hstack((lb[i, :5], kpt))
                 lb = kpts
-                assert lb.shape[1] == 39, "labels require 39 columns each after removing occlusion paramater"
+                assert lb.shape[1] == 39, "labels require 39 columns each after removing occlusion parameter"
             else:
                 assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
             assert (lb >= 0).all(), f"negative label values {lb[lb < 0]}"
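
Note on the 39-column assertion: a keypoint label row holds 5 box values plus 17 keypoints stored as (x, y, occlusion) triplets, and deleting every third (occlusion) value leaves 5 + 17 * 2 = 39 columns. A standalone sketch of the same arithmetic, with a 56-column row assumed for illustration:

import numpy as np

row = np.arange(56.0)  # hypothetical label row: 5 box values + 17 * (x, y, occlusion)
kpt = np.delete(row[5:], np.arange(2, row.shape[0] - 5, 3))  # drop each occlusion flag
out = np.hstack((row[:5], kpt))
assert out.shape[0] == 39  # matches the corrected assertion above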
ultralytics/yolo/engine/model.py (new file, 63 lines)

@@ -0,0 +1,63 @@
+"""
+Top-level YOLO model interface. First principle usage example - https://github.com/ultralytics/ultralytics/issues/13
+"""
+import torch
+import yaml
+
+import ultralytics.yolo as yolo
+from ultralytics.yolo.utils import LOGGER
+from ultralytics.yolo.utils.checks import check_yaml
+from ultralytics.yolo.utils.modeling.tasks import ClassificationModel, DetectionModel, SegmentationModel
+
+# map head: [model, trainer]
+MODEL_MAP = {
+    "Classify": [ClassificationModel, 'yolo.VERSION.classify.train.ClassificationTrainer'],
+    "Detect": [ClassificationModel, 'yolo.VERSION.classify.train.ClassificationTrainer'],  # temp
+    "Segment": []}
+
+
+class YOLO:
+
+    def __init__(self, version=8) -> None:
+        self.version = version
+        self.model = None
+        self.trainer = None
+        self.pretrained_weights = None
+
+    def new(self, cfg: str):
+        cfg = check_yaml(cfg)  # check YAML
+        self.model, self.trainer = self._get_model_and_trainer(cfg)
+
+    def load(self, weights, autodownload=True):
+        if not isinstance(self.pretrained_weights, type(None)):
+            LOGGER.info("Overwriting weights")
+        # TODO: weights = smart_file_loader(weights)
+        if self.model:
+            self.model.load(weights)
+            LOGGER.info("Checkpoint loaded successfully")
+        else:
+            # TODO: infer model and trainer
+            pass
+
+        self.pretrained_weights = weights
+
+    def reset(self):
+        pass
+
+    def train(self, **kwargs):
+        if 'data' not in kwargs:
+            raise Exception("data is required to train")
+        if not self.model:
+            raise Exception("model not initialized. Use .new() or .load()")
+        kwargs["model"] = self.model
+        trainer = self.trainer(overrides=kwargs)
+        trainer.train()
+
+    def _get_model_and_trainer(self, cfg):
+        with open(cfg, encoding='ascii', errors='ignore') as f:
+            cfg = yaml.safe_load(f)  # model dict
+        model, trainer = MODEL_MAP[cfg["head"][-1][-2]]
+        # warning: eval is unsafe. Use with caution
+        trainer = eval(trainer.replace("VERSION", f"v{self.version}"))
+
+        return model(cfg), trainer
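
Note on _get_model_and_trainer: in YOLOv5-style model YAMLs each head row has the form [from, number, module, args], so cfg["head"][-1][-2] reads the module name of the last head layer (e.g. "Classify") and uses it to key MODEL_MAP. A minimal sketch, with the YAML content assumed for illustration:

import yaml

cfg = yaml.safe_load("""
head:
  - [-1, 1, Conv, [256, 3]]
  - [-1, 1, Classify, [1000]]
""")
print(cfg["head"][-1][-2])  # "Classify" -> (ClassificationModel, trainer import path)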
@@ -7,7 +7,7 @@ import time
 from collections import defaultdict
 from datetime import datetime
 from pathlib import Path
-from typing import Union
+from typing import Dict, Union

 import torch
 import torch.distributed as dist
@@ -29,30 +29,29 @@ DEFAULT_CONFIG = "defaults.yaml"

 class BaseTrainer:

-    def __init__(self, config=CONFIG_PATH_ABS / DEFAULT_CONFIG):
+    def __init__(self, config=CONFIG_PATH_ABS / DEFAULT_CONFIG, overrides={}):
         self.console = LOGGER
-        self.model, self.data, self.train, self.hyps = self._get_config(config)
+        self.args = self._get_config(config, overrides)
         self.validator = None
         self.callbacks = defaultdict(list)
-        self.console.info(f"Training config: \n train: \n {self.train} \n hyps: \n {self.hyps}")  # to debug
+        self.console.info(f"Training config: \n args: \n {self.args}")  # to debug
         # Directories
-        self.save_dir = increment_path(Path(self.train.project) / self.train.name, exist_ok=self.train.exist_ok)
+        self.save_dir = increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
         self.wdir = self.save_dir / 'weights'
         self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
         self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'

         # Save run settings
-        save_yaml(self.save_dir / 'train.yaml', OmegaConf.to_container(self.train, resolve=True))
+        save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True))

         # device
-        self.device = utils.torch_utils.select_device(self.train.device, self.train.batch_size)
+        self.device = utils.torch_utils.select_device(self.args.device, self.args.batch_size)
         self.console.info(f"running on device {self.device}")
         self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')

         # Model and Dataloaders.
-        self.trainset, self.testset = self.get_dataset()  # initialize dataset before as nc is needed for model
-        self.model = self.get_model()
-        self.model = self.model.to(self.device)
+        self.trainset, self.testset = self.get_dataset(self.args.data)
+        self.model = self.get_model(self.args.model, self.args.pretrained).to(self.device)

         # epoch level metrics
         self.metrics = {}  # handle metrics returned by validator
@@ -63,18 +62,24 @@ class BaseTrainer:
         for callback, func in loggers.default_callbacks.items():
             self.add_callback(callback, func)

-    def _get_config(self, config: Union[str, Path, DictConfig] = None):
+    def _get_config(self, config: Union[str, DictConfig], overrides: Union[str, Dict] = {}):
         """
         Accepts yaml file name or DictConfig containing experiment configuration.
-        Returns train and hyps namespace
+        Returns training args namespace
         :param config: Optional file name or DictConfig object
         """
-        try:
-            if isinstance(config, (str, Path)):
-                config = OmegaConf.load(config)
-            return config.model, config.data, config.train, config.hyps
-        except KeyError as e:
-            raise KeyError("Missing key(s) in config") from e
+        if isinstance(config, (str, Path)):
+            config = OmegaConf.load(config)
+        elif isinstance(config, Dict):
+            config = OmegaConf.create(config)
+
+        # override
+        if isinstance(overrides, str):
+            overrides = OmegaConf.load(overrides)
+        elif isinstance(overrides, Dict):
+            overrides = OmegaConf.create(overrides)
+
+        return OmegaConf.merge(config, overrides)

     def add_callback(self, onevent: str, callback):
         """
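
Note: the rewritten _get_config boils down to an OmegaConf merge in which override keys win over the defaults, whether the overrides arrive as a YAML path or a plain dict. A minimal sketch of the merge semantics, assuming omegaconf is installed:

from omegaconf import OmegaConf

defaults = OmegaConf.create({"epochs": 300, "lr0": 0.001, "optimizer": "Adam"})
overrides = OmegaConf.create({"epochs": 1, "lr0": 0.01})
args = OmegaConf.merge(defaults, overrides)
print(args.epochs, args.lr0, args.optimizer)  # 1 0.01 Adam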
@@ -92,7 +97,7 @@ class BaseTrainer:
         for callback in self.callbacks.get(onevent, []):
             callback(self)

-    def run(self):
+    def train(self):
         world_size = torch.cuda.device_count()
         if world_size > 1:
             mp.spawn(self._do_train, args=(world_size,), nprocs=world_size, join=True)
@@ -109,21 +114,21 @@ class BaseTrainer:
         dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
         self.model = self.model.to(self.device)
         self.model = DDP(self.model, device_ids=[rank])
-        self.train.batch_size = self.train.batch_size // world_size
+        self.args.batch_size = self.args.batch_size // world_size

     def _setup_train(self, rank):
         """
         Builds dataloaders and optimizer on correct rank process
         """
         self.optimizer = build_optimizer(model=self.model,
-                                         name=self.train.optimizer,
-                                         lr=self.hyps.lr0,
-                                         momentum=self.hyps.momentum,
-                                         decay=self.hyps.weight_decay)
-        self.train_loader = self.get_dataloader(self.trainset, batch_size=self.train.batch_size, rank=rank)
+                                         name=self.args.optimizer,
+                                         lr=self.args.lr0,
+                                         momentum=self.args.momentum,
+                                         decay=self.args.weight_decay)
+        self.train_loader = self.get_dataloader(self.trainset, batch_size=self.args.batch_size, rank=rank)
         if rank in {0, -1}:
             print(" Creating testloader rank :", rank)
-            self.test_loader = self.get_dataloader(self.testset, batch_size=self.train.batch_size * 2, rank=rank)
+            self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=rank)
             self.validator = self.get_validator()
             print("created testloader :", rank)
@@ -138,7 +143,7 @@ class BaseTrainer:
         self.epoch_time = None
         self.epoch_time_start = time.time()
         self.train_time_start = time.time()
-        for epoch in range(self.train.epochs):
+        for epoch in range(self.args.epochs):
             # callback hook. on_epoch_start
             self.model.train()
             pbar = enumerate(self.train_loader)
@@ -165,7 +170,7 @@ class BaseTrainer:
                 # log
                 mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
                 if rank in {-1, 0}:
-                    pbar.desc = f"{f'{epoch + 1}/{self.train.epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36
+                    pbar.desc = f"{f'{epoch + 1}/{self.args.epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36

             if rank in [-1, 0]:
                 # validation
@@ -174,7 +179,7 @@ class BaseTrainer:
                 # callback: on_val_end()

                 # save model
-                if (not self.train.nosave) or (self.epoch + 1 == self.train.epochs):
+                if (not self.args.nosave) or (self.epoch + 1 == self.args.epochs):
                     self.save_model()
                     # callback; on_model_save
@@ -198,7 +203,7 @@ class BaseTrainer:
                 'ema': None,  # deepcopy(ema.ema).half(),
                 'updates': None,  # ema.updates,
                 'optimizer': None,  # optimizer.state_dict(),
-                'train_args': self.train,
+                'train_args': self.args,
                 'date': datetime.now().isoformat()}

         # Save last, best and delete
@@ -207,22 +212,22 @@ class BaseTrainer:
             torch.save(ckpt, self.best)
         del ckpt

-    def get_dataloader(self, path):
+    def get_dataloader(self, dataset_path, batch_size=16, rank=0):
         """
         Returns dataloader derived from torch.data.Dataloader
         """
         pass

-    def get_dataset(self):
+    def get_dataset(self, data):
         """
-        Uses self.dataset to download the dataset if needed and verify it.
+        Download the dataset if needed and verify it.
         Returns train and val split datasets
         """
         pass

-    def get_model(self):
+    def get_model(self, model, pretrained=True):
         """
-        Uses self.model to load/create/download dataset for any task
+        load/create/download model for any task
         """
         pass
@@ -238,7 +243,7 @@ class BaseTrainer:

     def preprocess_batch(self, images, labels):
         """
-        Allows custom preprocessing model inputs and ground truths depeding on task type
+        Allows custom preprocessing model inputs and ground truths depending on task type
         """
         return images.to(self.device, non_blocking=True), labels.to(self.device)
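
Note: preprocess_batch is the hook task trainers override for input-specific handling; the base version only moves tensors to the trainer's device. A hypothetical subclass sketch:

class MyDetectionTrainer(BaseTrainer):  # hypothetical task trainer
    def preprocess_batch(self, images, labels):
        images = images.float() / 255  # e.g. scale uint8 images to [0, 1]
        return images.to(self.device, non_blocking=True), labels.to(self.device)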
@@ -1,53 +1,56 @@
 model: null
 data: null
-train:
-  epochs: 300
-  batch_size: 16
-  img_size: 640
-  nosave: False
-  cache: False # True/ram for ram, or disc
-  device: '' # cuda device, i.e. 0 or 0,1,2,3 or cpu
-  workers: 8
-  project: "ultralytics-yolo"
-  name: "exp" # TODO: make this informative, maybe exp{#number}_{datetime} ?
-  exist_ok: False
-  pretrained: False
-  optimizer: "Adam" # choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
-  verbose: False
-  seed: 0
-  local_rank: -1
-
-hyps:
-  lr0: 0.001 # initial learning rate (SGD=1E-2, Adam=1E-3)
-  lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
-  momentum: 0.937 # SGD momentum/Adam beta1
-  weight_decay: 0.0005 # optimizer weight decay 5e-4
-  warmup_epochs: 3.0 # warmup epochs (fractions ok)
-  warmup_momentum: 0.8 # warmup initial momentum
-  warmup_bias_lr: 0.1 # warmup initial bias lr
-  box: 0.05 # box loss gain
-  cls: 0.5 # cls loss gain
-  cls_pw: 1.0 # cls BCELoss positive_weight
-  obj: 1.0 # obj loss gain (scale with pixels)
-  obj_pw: 1.0 # obj BCELoss positive_weight
-  iou_t: 0.20 # IoU training threshold
-  anchor_t: 4.0 # anchor-multiple threshold
-  # anchors: 3 # anchors per output layer (0 to ignore)
-  fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5)
-  hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
-  hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
-  hsv_v: 0.4 # image HSV-Value augmentation (fraction)
-  degrees: 0.0 # image rotation (+/- deg)
-  translate: 0.1 # image translation (+/- fraction)
-  scale: 0.5 # image scale (+/- gain)
-  shear: 0.0 # image shear (+/- deg)
-  perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
-  flipud: 0.0 # image flip up-down (probability)
-  fliplr: 0.5 # image flip left-right (probability)
-  mosaic: 1.0 # image mosaic (probability)
-  mixup: 0.0 # image mixup (probability)
-  copy_paste: 0.0 # segment copy-paste (probability)
+# Training options
+epochs: 300
+batch_size: 16
+img_size: 640
+nosave: False
+cache: False # True/ram for ram, or disc
+device: '' # cuda device, i.e. 0 or 0,1,2,3 or cpu
+workers: 8
+project: "ultralytics-yolo"
+name: "exp" # TODO: make this informative, maybe exp{#number}_{datetime} ?
+exist_ok: False
+pretrained: False
+optimizer: "Adam" # choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
+verbose: False
+seed: 0
+local_rank: -1
+#-----------------------------------#
+
+# Hyper-parameters
+lr0: 0.001 # initial learning rate (SGD=1E-2, Adam=1E-3)
+lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
+momentum: 0.937 # SGD momentum/Adam beta1
+weight_decay: 0.0005 # optimizer weight decay 5e-4
+warmup_epochs: 3.0 # warmup epochs (fractions ok)
+warmup_momentum: 0.8 # warmup initial momentum
+warmup_bias_lr: 0.1 # warmup initial bias lr
+box: 0.05 # box loss gain
+cls: 0.5 # cls loss gain
+cls_pw: 1.0 # cls BCELoss positive_weight
+obj: 1.0 # obj loss gain (scale with pixels)
+obj_pw: 1.0 # obj BCELoss positive_weight
+iou_t: 0.20 # IoU training threshold
+anchor_t: 4.0 # anchor-multiple threshold
+# anchors: 3 # anchors per output layer (0 to ignore)
+fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5)
+hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
+hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
+hsv_v: 0.4 # image HSV-Value augmentation (fraction)
+degrees: 0.0 # image rotation (+/- deg)
+translate: 0.1 # image translation (+/- fraction)
+scale: 0.5 # image scale (+/- gain)
+shear: 0.0 # image shear (+/- deg)
+perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
+flipud: 0.0 # image flip up-down (probability)
+fliplr: 0.5 # image flip left-right (probability)
+mosaic: 1.0 # image mosaic (probability)
+mixup: 0.0 # image mixup (probability)
+copy_paste: 0.0 # segment copy-paste (probability)

+# Hydra configs -------------------------------------
 # to disable hydra directory creation
 hydra:
   output_subdir: null
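
Note: with the train:/hyps: nesting removed, every option lives at the top level of the merged config, which is why the trainer hunks above read self.args.epochs and self.args.lr0 instead of self.train.epochs and self.hyps.lr0. A minimal sketch, with the file path assumed for illustration:

from omegaconf import OmegaConf

args = OmegaConf.load("defaults.yaml")
print(args.epochs, args.lr0)  # 300 0.001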
@@ -8,7 +8,8 @@ from ultralytics.yolo.utils import LOGGER
 from ultralytics.yolo.utils.anchors import check_anchor_order
 from ultralytics.yolo.utils.modeling import parse_model
 from ultralytics.yolo.utils.modeling.modules import *
-from ultralytics.yolo.utils.torch_utils import fuse_conv_and_bn, initialize_weights, model_info, scale_img, time_sync
+from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_state_dicts, model_info,
+                                                scale_img, time_sync)


 class BaseModel(nn.Module):
@@ -67,6 +68,10 @@ class BaseModel(nn.Module):
             m.anchor_grid = list(map(fn, m.anchor_grid))
         return self

+    def load(self, weights):
+        # Force all tasks implement this function
+        raise NotImplementedError("This function needs to be implemented by derived classes!")
+

 class DetectionModel(BaseModel):
     # YOLO detection model
@@ -166,6 +171,12 @@ class DetectionModel(BaseModel):
             b.data[:, 5:5 + m.nc] += math.log(0.6 / (m.nc - 0.99999)) if cf is None else torch.log(cf / cf.sum())  # cls
             mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

+    def load(self, weights):
+        ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak
+        csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
+        csd = intersect_state_dicts(csd, self.state_dict())  # intersect
+        self.load_state_dict(csd, strict=False)  # load
+

 class SegmentationModel(DetectionModel):
     # YOLOv5 segmentation model
@@ -197,3 +208,9 @@ class ClassificationModel(BaseModel):
     def _from_yaml(self, cfg):
         # Create a YOLOv5 classification model from a *.yaml file
         self.model = None
+
+    def load(self, weights):
+        ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak
+        csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
+        csd = intersect_state_dicts(csd, self.state_dict())  # intersect
+        self.load_state_dict(csd, strict=False)  # load
@@ -174,3 +174,8 @@ def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
         return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)

     return decorate
+
+
+def intersect_state_dicts(da, db, exclude=()):
+    # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
+    return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
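
Note: intersect_state_dicts keeps only entries whose key exists in both state dicts with identical tensor shapes, which is what lets the load() methods above transfer compatible weights with strict=False. A standalone sketch of the behavior:

import torch

def intersect_state_dicts(da, db, exclude=()):  # copied from the hunk above
    return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}

da = {"conv.weight": torch.ones(3, 3), "fc.weight": torch.ones(10, 5)}
db = {"conv.weight": torch.zeros(3, 3), "fc.weight": torch.zeros(10, 8)}
print(intersect_state_dicts(da, db).keys())  # dict_keys(['conv.weight']); fc shapes differ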
@@ -1,3 +1,4 @@
-from ultralytics.yolo.v8.classify import train
+from ultralytics.yolo.v8.classify.train import ClassificationTrainer
+from ultralytics.yolo.v8.classify.val import ClassificationValidator

 __all__ = ["train"]
@@ -5,11 +5,10 @@ from pathlib import Path
 import hydra
 import torch
 import torchvision
-from val import ClassificationValidator

-from ultralytics.yolo import BaseTrainer, v8
+from ultralytics.yolo import v8
 from ultralytics.yolo.data import build_classification_dataloader
-from ultralytics.yolo.engine.trainer import CONFIG_PATH_ABS, DEFAULT_CONFIG
+from ultralytics.yolo.engine.trainer import CONFIG_PATH_ABS, DEFAULT_CONFIG, BaseTrainer
 from ultralytics.yolo.utils.downloads import download
 from ultralytics.yolo.utils.files import WorkingDirectory
 from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, torch_distributed_zero_first
@@ -18,9 +17,9 @@ from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, torch_distributed_zero_first
 # BaseTrainer python usage
 class ClassificationTrainer(BaseTrainer):

-    def get_dataset(self):
+    def get_dataset(self, dataset):
         # temporary solution. Replace with new ultralytics.yolo.ClassificationDataset module
-        data = Path("datasets") / self.data
+        data = Path("datasets") / dataset
         with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(Path.cwd()):
             data_dir = data if data.is_dir() else (Path.cwd() / data)
             if not data_dir.is_dir():
@@ -29,7 +28,7 @@ class ClassificationTrainer(BaseTrainer):
                 if str(data) == 'imagenet':
                     subprocess.run(f"bash {v8.ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
                 else:
-                    url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{self.data}.zip'
+                    url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
                     download(url, dir=data_dir.parent)
                 # TODO: add colorstr
                 s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {'bold', data_dir}\n"
@@ -39,17 +38,18 @@ class ClassificationTrainer(BaseTrainer):

         return train_set, test_set

-    def get_dataloader(self, dataset, batch_size=None, rank=0):
-        return build_classification_dataloader(path=dataset, batch_size=self.train.batch_size, rank=rank)
+    def get_dataloader(self, dataset_path, batch_size=None, rank=0):
+        return build_classification_dataloader(path=dataset_path, batch_size=self.args.batch_size, rank=rank)

-    def get_model(self):
+    def get_model(self, model, pretrained):
         # temp. minimal. only supports torchvision models
-        if self.model in torchvision.models.__dict__:  # TorchVision models i.e. resnet50, efficientnet_b0
-            model = torchvision.models.__dict__[self.model](weights='IMAGENET1K_V1' if self.train.pretrained else None)
+        model = self.args.model
+        if model in torchvision.models.__dict__:  # TorchVision models i.e. resnet50, efficientnet_b0
+            model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
         else:
-            raise ModuleNotFoundError(f'--model {self.model} not found.')
+            raise ModuleNotFoundError(f'--model {model} not found.')
         for m in model.modules():
-            if not self.train.pretrained and hasattr(m, 'reset_parameters'):
+            if not pretrained and hasattr(m, 'reset_parameters'):
                 m.reset_parameters()
         for p in model.parameters():
             p.requires_grad = True  # for training
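
Note: get_model now resolves the name through torchvision's module dictionary, so any torchvision classifier name (resnet50, squeezenet1_0, ...) works. A minimal sketch, assuming torchvision >= 0.13 for the weights argument:

import torchvision

name = "squeezenet1_0"
assert name in torchvision.models.__dict__
model = torchvision.models.__dict__[name](weights=None)  # or weights='IMAGENET1K_V1' for pretrained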
@@ -57,7 +57,7 @@ class ClassificationTrainer(BaseTrainer):
         return model

     def get_validator(self):
-        return ClassificationValidator(self.test_loader, self.device, logger=self.console)  # validator
+        return v8.classify.ClassificationValidator(self.test_loader, self.device, logger=self.console)

     def criterion(self, preds, targets):
         return torch.nn.functional.cross_entropy(preds, targets)
@@ -66,17 +66,17 @@ class ClassificationTrainer(BaseTrainer):
 @hydra.main(version_base=None, config_path=CONFIG_PATH_ABS, config_name=str(DEFAULT_CONFIG).split(".")[0])
 def train(cfg):
     cfg.model = cfg.model or "squeezenet1_0"
-    cfg.data = cfg.data or "imagenette160"  # or yolo.ClassificationDataset("mnist")
+    cfg.data = cfg.data or "imagenette"  # or yolo.ClassificationDataset("mnist")
     trainer = ClassificationTrainer(cfg)
-    trainer.run()
+    trainer.train()


 if __name__ == "__main__":
     """
     CLI usage:
-    python ../path/to/train.py train.epochs=10 train.project="name" hyps.lr0=0.1
+    python ../path/to/train.py args.epochs=10 args.project="name" hyps.lr0=0.1

     TODO:
-    Direct cli support, i.e, yolov8 classify_train train.epochs 10
+    Direct cli support, i.e, yolov8 classify_train args.epochs 10
     """
     train()
@@ -1,9 +1,9 @@
 import torch

-from ultralytics import yolo
+from ultralytics.yolo.engine.validator import BaseValidator


-class ClassificationValidator(yolo.BaseValidator):
+class ClassificationValidator(BaseValidator):

     def init_metrics(self):
         self.correct = torch.tensor([])