mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
Metrics and loss structure (#28)
Co-authored-by: Ayush Chaurasia <ayush.chuararsia@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
d0b3c9812b
commit
c5cb76b356
@ -10,8 +10,8 @@ pip install . # (dev)
|
||||
# pip install ultralytics (production)
|
||||
```
|
||||
|
||||
|
||||
### Usage
|
||||
|
||||
```python
|
||||
import ultralytics
|
||||
from ultralytics import HUB, YOLO
|
||||
|
@ -1,3 +1,4 @@
|
||||
from .engine.trainer import BaseTrainer
|
||||
from .engine.validator import BaseValidator
|
||||
|
||||
__all__ = ["BaseTrainer"] # allow simpler import
|
||||
__all__ = ["BaseTrainer", "BaseValidator"] # allow simpler import
|
||||
|
@ -2,13 +2,10 @@ from itertools import repeat
|
||||
from multiprocessing.pool import Pool
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..utils.general import LOGGER, NUM_THREADS
|
||||
from ..utils.general import NUM_THREADS
|
||||
from .augment import *
|
||||
from .base import BaseDataset
|
||||
from .utils import BAR_FORMAT, HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label
|
||||
|
@ -28,20 +28,11 @@ DEFAULT_CONFIG = "defaults.yaml"
|
||||
|
||||
class BaseTrainer:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
data: str,
|
||||
criterion, # Should we create our own base loss classes? yolo.losses -> v8.losses.clfLoss
|
||||
validator=None,
|
||||
config=CONFIG_PATH_ABS / DEFAULT_CONFIG):
|
||||
def __init__(self, config=CONFIG_PATH_ABS / DEFAULT_CONFIG):
|
||||
self.console = LOGGER
|
||||
self.model = model
|
||||
self.data = data
|
||||
self.criterion = criterion # ComputeLoss object TODO: create yolo.Loss classes
|
||||
self.validator = val # Dummy validator
|
||||
self.model, self.data, self.train, self.hyps = self._get_config(config)
|
||||
self.validator = None
|
||||
self.callbacks = defaultdict(list)
|
||||
self.train, self.hyps = self._get_config(config)
|
||||
self.console.info(f"Training config: \n train: \n {self.train} \n hyps: \n {self.hyps}") # to debug
|
||||
# Directories
|
||||
self.save_dir = utils.increment_path(Path(self.train.project) / self.train.name, exist_ok=self.train.exist_ok)
|
||||
@ -57,7 +48,7 @@ class BaseTrainer:
|
||||
self.console.info(f"running on device {self.device}")
|
||||
self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')
|
||||
|
||||
# Model and Dataloaders. TBD: Should we move this inside trainer?
|
||||
# Model and Dataloaders.
|
||||
self.trainset, self.testset = self.get_dataset() # initialize dataset before as nc is needed for model
|
||||
self.model = self.get_model()
|
||||
self.model = self.model.to(self.device)
|
||||
@ -80,9 +71,9 @@ class BaseTrainer:
|
||||
try:
|
||||
if isinstance(config, (str, Path)):
|
||||
config = OmegaConf.load(config)
|
||||
return config.train, config.hyps
|
||||
return config.model, config.data, config.train, config.hyps
|
||||
except KeyError as e:
|
||||
raise Exception("Missing key(s) in config") from e
|
||||
raise KeyError("Missing key(s) in config") from e
|
||||
|
||||
def add_callback(self, onevent: str, callback):
|
||||
"""
|
||||
@ -131,10 +122,9 @@ class BaseTrainer:
|
||||
self.train_loader = self.get_dataloader(self.trainset, batch_size=self.train.batch_size, rank=rank)
|
||||
if rank in {0, -1}:
|
||||
print(" Creating testloader rank :", rank)
|
||||
# self.test_loader = self.get_dataloader(self.testset,
|
||||
# batch_size=self.train.batch_size*2,
|
||||
# rank=rank)
|
||||
# print("created testloader :", rank)
|
||||
self.test_loader = self.get_dataloader(self.testset, batch_size=self.train.batch_size * 2, rank=rank)
|
||||
self.validator = self.get_validator()
|
||||
print("created testloader :", rank)
|
||||
|
||||
def _do_train(self, rank, world_size):
|
||||
if world_size > 1:
|
||||
@ -235,11 +225,8 @@ class BaseTrainer:
|
||||
"""
|
||||
pass
|
||||
|
||||
def set_criterion(self, criterion):
|
||||
"""
|
||||
:param criterion: yolo.Loss object.
|
||||
"""
|
||||
self.criterion = criterion
|
||||
def get_validator(self):
|
||||
pass
|
||||
|
||||
def optimizer_step(self):
|
||||
self.scaler.unscale_(self.optimizer) # unscale gradients
|
||||
@ -265,6 +252,12 @@ class BaseTrainer:
|
||||
if not self.best_fitness or self.best_fitness < self.fitness:
|
||||
self.best_fitness = self.fitness
|
||||
|
||||
def build_targets(self, preds, targets):
|
||||
pass
|
||||
|
||||
def criterion(self, preds, targets):
|
||||
pass
|
||||
|
||||
def progress_string(self):
|
||||
"""
|
||||
Returns progress string depending on task type.
|
||||
|
@ -0,0 +1,105 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics.yolo.utils import Profile, select_device
|
||||
|
||||
|
||||
class BaseValidator:
|
||||
"""
|
||||
Base validator class.
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader, device='', half=False, pbar=None, logger=None):
|
||||
self.dataloader = dataloader
|
||||
self.half = half
|
||||
self.device = select_device(device, dataloader.batch_size)
|
||||
self.pbar = pbar
|
||||
self.logger = logger or logging.getLogger()
|
||||
|
||||
def __call__(self, trainer=None, model=None):
|
||||
"""
|
||||
Supports validation of a pre-trained model if passed or a model being trained
|
||||
if trainer is passed (trainer gets priority).
|
||||
"""
|
||||
training = trainer is not None
|
||||
# trainer = trainer or self.trainer_class.get_trainer()
|
||||
assert training or model is not None, "Either trainer or model is needed for validation"
|
||||
if training:
|
||||
model = trainer.model
|
||||
self.half &= self.device.type != 'cpu'
|
||||
model = model.half() if self.half else model
|
||||
else: # TODO: handle this when detectMultiBackend is supported
|
||||
# model = DetectMultiBacked(model)
|
||||
pass
|
||||
|
||||
model.eval()
|
||||
dt = Profile(), Profile(), Profile(), Profile()
|
||||
loss = 0
|
||||
n_batches = len(self.dataloader)
|
||||
desc = self.set_desc()
|
||||
bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
|
||||
self.init_metrics()
|
||||
with torch.cuda.amp.autocast(enabled=self.device.type != 'cpu'):
|
||||
for images, labels in bar:
|
||||
# pre-process
|
||||
with dt[0]:
|
||||
images, labels = self.preprocess_batch(images, labels)
|
||||
|
||||
# inference
|
||||
with dt[1]:
|
||||
preds = model(images)
|
||||
# TODO: remember to add native augmentation support when implementing model, like:
|
||||
# preds, train_out = model(im, augment=augment)
|
||||
|
||||
# loss
|
||||
with dt[2]:
|
||||
if training:
|
||||
loss += trainer.criterion(preds, labels) / images.shape[0]
|
||||
|
||||
# pre-process predictions
|
||||
with dt[3]:
|
||||
preds = self.preprocess_preds(preds)
|
||||
|
||||
self.update_metrics(preds, labels)
|
||||
|
||||
stats = self.get_stats()
|
||||
self.check_stats(stats)
|
||||
|
||||
self.print_results()
|
||||
|
||||
# print speeds
|
||||
if not training:
|
||||
t = tuple(x.t / len(self.dataloader.dataset.samples) * 1E3 for x in dt) # speeds per image
|
||||
# shape = (self.dataloader.batch_size, 3, imgsz, imgsz)
|
||||
self.logger.info(
|
||||
'Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image at shape ' % t)
|
||||
|
||||
# TODO: implement save json
|
||||
|
||||
return stats
|
||||
|
||||
def preprocess_batch(self, images, labels):
|
||||
return images.to(self.device, non_blocking=True), labels.to(self.device)
|
||||
|
||||
def preprocess_preds(self, preds):
|
||||
return preds
|
||||
|
||||
def init_metrics(self):
|
||||
pass
|
||||
|
||||
def update_metrics(self, preds, targets):
|
||||
pass
|
||||
|
||||
def get_stats(self):
|
||||
pass
|
||||
|
||||
def check_stats(self, stats):
|
||||
pass
|
||||
|
||||
def print_results(self):
|
||||
pass
|
||||
|
||||
def set_desc(self):
|
||||
pass
|
@ -1,4 +1,4 @@
|
||||
from .general import WorkingDirectory, check_version, download, increment_path, save_yaml
|
||||
from .general import Profile, WorkingDirectory, check_version, download, increment_path, save_yaml
|
||||
from .torch_utils import LOCAL_RANK, RANK, WORLD_SIZE, DDP_model, select_device, torch_distributed_zero_first
|
||||
|
||||
__all__ = [
|
||||
@ -8,6 +8,7 @@ __all__ = [
|
||||
"WorkingDirectory",
|
||||
"download",
|
||||
"check_version",
|
||||
"Profile",
|
||||
# torch
|
||||
"torch_distributed_zero_first",
|
||||
"LOCAL_RANK",
|
||||
|
@ -1,3 +1,5 @@
|
||||
model: null
|
||||
data: null
|
||||
train:
|
||||
epochs: 300
|
||||
batch_size: 16
|
||||
|
@ -5,6 +5,7 @@ import logging
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import time
|
||||
import urllib
|
||||
from itertools import repeat
|
||||
from multiprocessing.pool import ThreadPool
|
||||
@ -208,7 +209,7 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False):
|
||||
return path
|
||||
|
||||
|
||||
def save_yaml(file='data.yaml', data={}):
|
||||
def save_yaml(file='data.yaml', data=None):
|
||||
# Single-line safe yaml saving
|
||||
with open(file, 'w') as f:
|
||||
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
|
||||
@ -278,7 +279,6 @@ class WorkingDirectory(contextlib.ContextDecorator):
|
||||
|
||||
def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
|
||||
# Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
|
||||
from utils.general import LOGGER
|
||||
|
||||
file = Path(file)
|
||||
assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
|
||||
@ -301,7 +301,6 @@ def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
|
||||
|
||||
def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'):
|
||||
# Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
|
||||
from utils.general import LOGGER
|
||||
|
||||
def github_assets(repository, version='latest'):
|
||||
# Return GitHub repo tag and assets (i.e. ['yolov5s.pt', 'yolov5m.pt', ...])
|
||||
@ -351,3 +350,23 @@ def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'):
|
||||
def get_model(model: str):
|
||||
# check for local weights
|
||||
pass
|
||||
|
||||
|
||||
class Profile(contextlib.ContextDecorator):
|
||||
# YOLOv5 Profile class. Usage: @Profile() decorator or 'with Profile():' context manager
|
||||
def __init__(self, t=0.0):
|
||||
self.t = t
|
||||
self.cuda = torch.cuda.is_available()
|
||||
|
||||
def __enter__(self):
|
||||
self.start = self.time()
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
self.dt = self.time() - self.start # delta-time
|
||||
self.t += self.dt # accumulate dt
|
||||
|
||||
def time(self):
|
||||
if self.cuda:
|
||||
torch.cuda.synchronize()
|
||||
return time.time()
|
||||
|
@ -2,6 +2,7 @@
|
||||
"""
|
||||
Model validation metrics
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
@ -4,10 +4,8 @@ from pathlib import Path
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
import torch.hub as hub
|
||||
import torchvision
|
||||
import torchvision.transforms as T
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from val import ClassificationValidator
|
||||
|
||||
from ultralytics.yolo import BaseTrainer, utils, v8
|
||||
from ultralytics.yolo.data import build_classification_dataloader
|
||||
@ -15,7 +13,7 @@ from ultralytics.yolo.engine.trainer import CONFIG_PATH_ABS, DEFAULT_CONFIG
|
||||
|
||||
|
||||
# BaseTrainer python usage
|
||||
class Trainer(BaseTrainer):
|
||||
class ClassificationTrainer(BaseTrainer):
|
||||
|
||||
def get_dataset(self):
|
||||
# temporary solution. Replace with new ultralytics.yolo.ClassificationDataset module
|
||||
@ -55,13 +53,18 @@ class Trainer(BaseTrainer):
|
||||
|
||||
return model
|
||||
|
||||
def get_validator(self):
|
||||
return ClassificationValidator(self.test_loader, self.device, logger=self.console) # validator
|
||||
|
||||
def criterion(self, preds, targets):
|
||||
return torch.nn.functional.cross_entropy(preds, targets)
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path=CONFIG_PATH_ABS, config_name=str(DEFAULT_CONFIG).split(".")[0])
|
||||
def train(cfg):
|
||||
model = "squeezenet1_0"
|
||||
dataset = "imagenette160" # or yolo.ClassificationDataset("mnist")
|
||||
criterion = torch.nn.CrossEntropyLoss() # yolo.Loss object
|
||||
trainer = Trainer(model, dataset, criterion, config=cfg)
|
||||
cfg.model = cfg.model or "squeezenet1_0"
|
||||
cfg.data = cfg.data or "imagenette160" # or yolo.ClassificationDataset("mnist")
|
||||
trainer = ClassificationTrainer(cfg)
|
||||
trainer.run()
|
||||
|
||||
|
||||
|
18
ultralytics/yolo/v8/classify/val.py
Normal file
18
ultralytics/yolo/v8/classify/val.py
Normal file
@ -0,0 +1,18 @@
|
||||
import torch
|
||||
|
||||
from ultralytics import yolo
|
||||
|
||||
|
||||
class ClassificationValidator(yolo.BaseValidator):
|
||||
|
||||
def init_metrics(self):
|
||||
self.correct = torch.tensor([])
|
||||
|
||||
def update_metrics(self, preds, targets):
|
||||
correct_in_batch = (targets[:, None] == preds).float()
|
||||
self.correct = torch.cat((self.correct, correct_in_batch))
|
||||
|
||||
def get_stats(self):
|
||||
acc = torch.stack((self.correct[:, 0], self.correct.max(1).values), dim=1) # (top1, top5) accuracy
|
||||
top1, top5 = acc.mean(0).tolist()
|
||||
return {"top1": top1, "top5": top5, "fitness": top5}
|
Loading…
x
Reference in New Issue
Block a user