mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Metrics and loss structure (#28)
Co-authored-by: Ayush Chaurasia <ayush.chuararsia@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
d0b3c9812b
commit
c5cb76b356
@ -10,8 +10,8 @@ pip install . # (dev)
|
|||||||
# pip install ultralytics (production)
|
# pip install ultralytics (production)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
### Usage
|
### Usage
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import ultralytics
|
import ultralytics
|
||||||
from ultralytics import HUB, YOLO
|
from ultralytics import HUB, YOLO
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
from .engine.trainer import BaseTrainer
|
from .engine.trainer import BaseTrainer
|
||||||
|
from .engine.validator import BaseValidator
|
||||||
|
|
||||||
__all__ = ["BaseTrainer"] # allow simpler import
|
__all__ = ["BaseTrainer", "BaseValidator"] # allow simpler import
|
||||||
|
@ -2,13 +2,10 @@ from itertools import repeat
|
|||||||
from multiprocessing.pool import Pool
|
from multiprocessing.pool import Pool
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torchvision
|
import torchvision
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from ..utils.general import LOGGER, NUM_THREADS
|
from ..utils.general import NUM_THREADS
|
||||||
from .augment import *
|
from .augment import *
|
||||||
from .base import BaseDataset
|
from .base import BaseDataset
|
||||||
from .utils import BAR_FORMAT, HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label
|
from .utils import BAR_FORMAT, HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label
|
||||||
|
@ -28,20 +28,11 @@ DEFAULT_CONFIG = "defaults.yaml"
|
|||||||
|
|
||||||
class BaseTrainer:
|
class BaseTrainer:
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, config=CONFIG_PATH_ABS / DEFAULT_CONFIG):
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
data: str,
|
|
||||||
criterion, # Should we create our own base loss classes? yolo.losses -> v8.losses.clfLoss
|
|
||||||
validator=None,
|
|
||||||
config=CONFIG_PATH_ABS / DEFAULT_CONFIG):
|
|
||||||
self.console = LOGGER
|
self.console = LOGGER
|
||||||
self.model = model
|
self.model, self.data, self.train, self.hyps = self._get_config(config)
|
||||||
self.data = data
|
self.validator = None
|
||||||
self.criterion = criterion # ComputeLoss object TODO: create yolo.Loss classes
|
|
||||||
self.validator = val # Dummy validator
|
|
||||||
self.callbacks = defaultdict(list)
|
self.callbacks = defaultdict(list)
|
||||||
self.train, self.hyps = self._get_config(config)
|
|
||||||
self.console.info(f"Training config: \n train: \n {self.train} \n hyps: \n {self.hyps}") # to debug
|
self.console.info(f"Training config: \n train: \n {self.train} \n hyps: \n {self.hyps}") # to debug
|
||||||
# Directories
|
# Directories
|
||||||
self.save_dir = utils.increment_path(Path(self.train.project) / self.train.name, exist_ok=self.train.exist_ok)
|
self.save_dir = utils.increment_path(Path(self.train.project) / self.train.name, exist_ok=self.train.exist_ok)
|
||||||
@ -57,7 +48,7 @@ class BaseTrainer:
|
|||||||
self.console.info(f"running on device {self.device}")
|
self.console.info(f"running on device {self.device}")
|
||||||
self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')
|
self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')
|
||||||
|
|
||||||
# Model and Dataloaders. TBD: Should we move this inside trainer?
|
# Model and Dataloaders.
|
||||||
self.trainset, self.testset = self.get_dataset() # initialize dataset before as nc is needed for model
|
self.trainset, self.testset = self.get_dataset() # initialize dataset before as nc is needed for model
|
||||||
self.model = self.get_model()
|
self.model = self.get_model()
|
||||||
self.model = self.model.to(self.device)
|
self.model = self.model.to(self.device)
|
||||||
@ -80,9 +71,9 @@ class BaseTrainer:
|
|||||||
try:
|
try:
|
||||||
if isinstance(config, (str, Path)):
|
if isinstance(config, (str, Path)):
|
||||||
config = OmegaConf.load(config)
|
config = OmegaConf.load(config)
|
||||||
return config.train, config.hyps
|
return config.model, config.data, config.train, config.hyps
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
raise Exception("Missing key(s) in config") from e
|
raise KeyError("Missing key(s) in config") from e
|
||||||
|
|
||||||
def add_callback(self, onevent: str, callback):
|
def add_callback(self, onevent: str, callback):
|
||||||
"""
|
"""
|
||||||
@ -131,10 +122,9 @@ class BaseTrainer:
|
|||||||
self.train_loader = self.get_dataloader(self.trainset, batch_size=self.train.batch_size, rank=rank)
|
self.train_loader = self.get_dataloader(self.trainset, batch_size=self.train.batch_size, rank=rank)
|
||||||
if rank in {0, -1}:
|
if rank in {0, -1}:
|
||||||
print(" Creating testloader rank :", rank)
|
print(" Creating testloader rank :", rank)
|
||||||
# self.test_loader = self.get_dataloader(self.testset,
|
self.test_loader = self.get_dataloader(self.testset, batch_size=self.train.batch_size * 2, rank=rank)
|
||||||
# batch_size=self.train.batch_size*2,
|
self.validator = self.get_validator()
|
||||||
# rank=rank)
|
print("created testloader :", rank)
|
||||||
# print("created testloader :", rank)
|
|
||||||
|
|
||||||
def _do_train(self, rank, world_size):
|
def _do_train(self, rank, world_size):
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
@ -235,11 +225,8 @@ class BaseTrainer:
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def set_criterion(self, criterion):
|
def get_validator(self):
|
||||||
"""
|
pass
|
||||||
:param criterion: yolo.Loss object.
|
|
||||||
"""
|
|
||||||
self.criterion = criterion
|
|
||||||
|
|
||||||
def optimizer_step(self):
|
def optimizer_step(self):
|
||||||
self.scaler.unscale_(self.optimizer) # unscale gradients
|
self.scaler.unscale_(self.optimizer) # unscale gradients
|
||||||
@ -265,6 +252,12 @@ class BaseTrainer:
|
|||||||
if not self.best_fitness or self.best_fitness < self.fitness:
|
if not self.best_fitness or self.best_fitness < self.fitness:
|
||||||
self.best_fitness = self.fitness
|
self.best_fitness = self.fitness
|
||||||
|
|
||||||
|
def build_targets(self, preds, targets):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def criterion(self, preds, targets):
|
||||||
|
pass
|
||||||
|
|
||||||
def progress_string(self):
|
def progress_string(self):
|
||||||
"""
|
"""
|
||||||
Returns progress string depending on task type.
|
Returns progress string depending on task type.
|
||||||
|
@ -0,0 +1,105 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from ultralytics.yolo.utils import Profile, select_device
|
||||||
|
|
||||||
|
|
||||||
|
class BaseValidator:
|
||||||
|
"""
|
||||||
|
Base validator class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dataloader, device='', half=False, pbar=None, logger=None):
|
||||||
|
self.dataloader = dataloader
|
||||||
|
self.half = half
|
||||||
|
self.device = select_device(device, dataloader.batch_size)
|
||||||
|
self.pbar = pbar
|
||||||
|
self.logger = logger or logging.getLogger()
|
||||||
|
|
||||||
|
def __call__(self, trainer=None, model=None):
|
||||||
|
"""
|
||||||
|
Supports validation of a pre-trained model if passed or a model being trained
|
||||||
|
if trainer is passed (trainer gets priority).
|
||||||
|
"""
|
||||||
|
training = trainer is not None
|
||||||
|
# trainer = trainer or self.trainer_class.get_trainer()
|
||||||
|
assert training or model is not None, "Either trainer or model is needed for validation"
|
||||||
|
if training:
|
||||||
|
model = trainer.model
|
||||||
|
self.half &= self.device.type != 'cpu'
|
||||||
|
model = model.half() if self.half else model
|
||||||
|
else: # TODO: handle this when detectMultiBackend is supported
|
||||||
|
# model = DetectMultiBacked(model)
|
||||||
|
pass
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
dt = Profile(), Profile(), Profile(), Profile()
|
||||||
|
loss = 0
|
||||||
|
n_batches = len(self.dataloader)
|
||||||
|
desc = self.set_desc()
|
||||||
|
bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
|
||||||
|
self.init_metrics()
|
||||||
|
with torch.cuda.amp.autocast(enabled=self.device.type != 'cpu'):
|
||||||
|
for images, labels in bar:
|
||||||
|
# pre-process
|
||||||
|
with dt[0]:
|
||||||
|
images, labels = self.preprocess_batch(images, labels)
|
||||||
|
|
||||||
|
# inference
|
||||||
|
with dt[1]:
|
||||||
|
preds = model(images)
|
||||||
|
# TODO: remember to add native augmentation support when implementing model, like:
|
||||||
|
# preds, train_out = model(im, augment=augment)
|
||||||
|
|
||||||
|
# loss
|
||||||
|
with dt[2]:
|
||||||
|
if training:
|
||||||
|
loss += trainer.criterion(preds, labels) / images.shape[0]
|
||||||
|
|
||||||
|
# pre-process predictions
|
||||||
|
with dt[3]:
|
||||||
|
preds = self.preprocess_preds(preds)
|
||||||
|
|
||||||
|
self.update_metrics(preds, labels)
|
||||||
|
|
||||||
|
stats = self.get_stats()
|
||||||
|
self.check_stats(stats)
|
||||||
|
|
||||||
|
self.print_results()
|
||||||
|
|
||||||
|
# print speeds
|
||||||
|
if not training:
|
||||||
|
t = tuple(x.t / len(self.dataloader.dataset.samples) * 1E3 for x in dt) # speeds per image
|
||||||
|
# shape = (self.dataloader.batch_size, 3, imgsz, imgsz)
|
||||||
|
self.logger.info(
|
||||||
|
'Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image at shape ' % t)
|
||||||
|
|
||||||
|
# TODO: implement save json
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
def preprocess_batch(self, images, labels):
|
||||||
|
return images.to(self.device, non_blocking=True), labels.to(self.device)
|
||||||
|
|
||||||
|
def preprocess_preds(self, preds):
|
||||||
|
return preds
|
||||||
|
|
||||||
|
def init_metrics(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def update_metrics(self, preds, targets):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_stats(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def check_stats(self, stats):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def print_results(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def set_desc(self):
|
||||||
|
pass
|
@ -1,4 +1,4 @@
|
|||||||
from .general import WorkingDirectory, check_version, download, increment_path, save_yaml
|
from .general import Profile, WorkingDirectory, check_version, download, increment_path, save_yaml
|
||||||
from .torch_utils import LOCAL_RANK, RANK, WORLD_SIZE, DDP_model, select_device, torch_distributed_zero_first
|
from .torch_utils import LOCAL_RANK, RANK, WORLD_SIZE, DDP_model, select_device, torch_distributed_zero_first
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -8,6 +8,7 @@ __all__ = [
|
|||||||
"WorkingDirectory",
|
"WorkingDirectory",
|
||||||
"download",
|
"download",
|
||||||
"check_version",
|
"check_version",
|
||||||
|
"Profile",
|
||||||
# torch
|
# torch
|
||||||
"torch_distributed_zero_first",
|
"torch_distributed_zero_first",
|
||||||
"LOCAL_RANK",
|
"LOCAL_RANK",
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
model: null
|
||||||
|
data: null
|
||||||
train:
|
train:
|
||||||
epochs: 300
|
epochs: 300
|
||||||
batch_size: 16
|
batch_size: 16
|
||||||
|
@ -5,6 +5,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import time
|
||||||
import urllib
|
import urllib
|
||||||
from itertools import repeat
|
from itertools import repeat
|
||||||
from multiprocessing.pool import ThreadPool
|
from multiprocessing.pool import ThreadPool
|
||||||
@ -208,7 +209,7 @@ def increment_path(path, exist_ok=False, sep='', mkdir=False):
|
|||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
def save_yaml(file='data.yaml', data={}):
|
def save_yaml(file='data.yaml', data=None):
|
||||||
# Single-line safe yaml saving
|
# Single-line safe yaml saving
|
||||||
with open(file, 'w') as f:
|
with open(file, 'w') as f:
|
||||||
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
|
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
|
||||||
@ -278,7 +279,6 @@ class WorkingDirectory(contextlib.ContextDecorator):
|
|||||||
|
|
||||||
def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
|
def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
|
||||||
# Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
|
# Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
|
||||||
from utils.general import LOGGER
|
|
||||||
|
|
||||||
file = Path(file)
|
file = Path(file)
|
||||||
assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
|
assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
|
||||||
@ -301,7 +301,6 @@ def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
|
|||||||
|
|
||||||
def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'):
|
def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'):
|
||||||
# Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
|
# Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
|
||||||
from utils.general import LOGGER
|
|
||||||
|
|
||||||
def github_assets(repository, version='latest'):
|
def github_assets(repository, version='latest'):
|
||||||
# Return GitHub repo tag and assets (i.e. ['yolov5s.pt', 'yolov5m.pt', ...])
|
# Return GitHub repo tag and assets (i.e. ['yolov5s.pt', 'yolov5m.pt', ...])
|
||||||
@ -351,3 +350,23 @@ def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'):
|
|||||||
def get_model(model: str):
|
def get_model(model: str):
|
||||||
# check for local weights
|
# check for local weights
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Profile(contextlib.ContextDecorator):
|
||||||
|
# YOLOv5 Profile class. Usage: @Profile() decorator or 'with Profile():' context manager
|
||||||
|
def __init__(self, t=0.0):
|
||||||
|
self.t = t
|
||||||
|
self.cuda = torch.cuda.is_available()
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.start = self.time()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, type, value, traceback):
|
||||||
|
self.dt = self.time() - self.start # delta-time
|
||||||
|
self.t += self.dt # accumulate dt
|
||||||
|
|
||||||
|
def time(self):
|
||||||
|
if self.cuda:
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
return time.time()
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
"""
|
"""
|
||||||
Model validation metrics
|
Model validation metrics
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,10 +4,8 @@ from pathlib import Path
|
|||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
import torch
|
import torch
|
||||||
import torch.hub as hub
|
|
||||||
import torchvision
|
import torchvision
|
||||||
import torchvision.transforms as T
|
from val import ClassificationValidator
|
||||||
from omegaconf import DictConfig, OmegaConf
|
|
||||||
|
|
||||||
from ultralytics.yolo import BaseTrainer, utils, v8
|
from ultralytics.yolo import BaseTrainer, utils, v8
|
||||||
from ultralytics.yolo.data import build_classification_dataloader
|
from ultralytics.yolo.data import build_classification_dataloader
|
||||||
@ -15,7 +13,7 @@ from ultralytics.yolo.engine.trainer import CONFIG_PATH_ABS, DEFAULT_CONFIG
|
|||||||
|
|
||||||
|
|
||||||
# BaseTrainer python usage
|
# BaseTrainer python usage
|
||||||
class Trainer(BaseTrainer):
|
class ClassificationTrainer(BaseTrainer):
|
||||||
|
|
||||||
def get_dataset(self):
|
def get_dataset(self):
|
||||||
# temporary solution. Replace with new ultralytics.yolo.ClassificationDataset module
|
# temporary solution. Replace with new ultralytics.yolo.ClassificationDataset module
|
||||||
@ -55,13 +53,18 @@ class Trainer(BaseTrainer):
|
|||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def get_validator(self):
|
||||||
|
return ClassificationValidator(self.test_loader, self.device, logger=self.console) # validator
|
||||||
|
|
||||||
|
def criterion(self, preds, targets):
|
||||||
|
return torch.nn.functional.cross_entropy(preds, targets)
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path=CONFIG_PATH_ABS, config_name=str(DEFAULT_CONFIG).split(".")[0])
|
@hydra.main(version_base=None, config_path=CONFIG_PATH_ABS, config_name=str(DEFAULT_CONFIG).split(".")[0])
|
||||||
def train(cfg):
|
def train(cfg):
|
||||||
model = "squeezenet1_0"
|
cfg.model = cfg.model or "squeezenet1_0"
|
||||||
dataset = "imagenette160" # or yolo.ClassificationDataset("mnist")
|
cfg.data = cfg.data or "imagenette160" # or yolo.ClassificationDataset("mnist")
|
||||||
criterion = torch.nn.CrossEntropyLoss() # yolo.Loss object
|
trainer = ClassificationTrainer(cfg)
|
||||||
trainer = Trainer(model, dataset, criterion, config=cfg)
|
|
||||||
trainer.run()
|
trainer.run()
|
||||||
|
|
||||||
|
|
||||||
|
18
ultralytics/yolo/v8/classify/val.py
Normal file
18
ultralytics/yolo/v8/classify/val.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from ultralytics import yolo
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationValidator(yolo.BaseValidator):
|
||||||
|
|
||||||
|
def init_metrics(self):
|
||||||
|
self.correct = torch.tensor([])
|
||||||
|
|
||||||
|
def update_metrics(self, preds, targets):
|
||||||
|
correct_in_batch = (targets[:, None] == preds).float()
|
||||||
|
self.correct = torch.cat((self.correct, correct_in_batch))
|
||||||
|
|
||||||
|
def get_stats(self):
|
||||||
|
acc = torch.stack((self.correct[:, 0], self.correct.max(1).values), dim=1) # (top1, top5) accuracy
|
||||||
|
top1, top5 = acc.mean(0).tolist()
|
||||||
|
return {"top1": top1, "top5": top5, "fitness": top5}
|
Loading…
x
Reference in New Issue
Block a user