mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
update segment training (#57)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
parent
d0b0fe2592
commit
3a241e4cea
@ -578,8 +578,8 @@ class Albumentations:
|
||||
# TODO: add support for segments and keypoints
|
||||
if self.transform and random.random() < self.p:
|
||||
new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed
|
||||
labels["img"] = new["image"]
|
||||
labels["cls"] = np.array(new["class_labels"])
|
||||
labels["img"] = new["image"]
|
||||
labels["cls"] = np.array(new["class_labels"])
|
||||
labels["instances"].update(bboxes=bboxes)
|
||||
return labels
|
||||
|
||||
@ -635,7 +635,7 @@ class Format:
|
||||
def _format_img(self, img):
|
||||
if len(img.shape) < 3:
|
||||
img = np.expand_dims(img, -1)
|
||||
img = np.ascontiguousarray(img.transpose(2, 0, 1))
|
||||
img = np.ascontiguousarray(img.transpose(2, 0, 1)[::-1])
|
||||
img = torch.from_numpy(img)
|
||||
return img
|
||||
|
||||
|
@ -151,7 +151,7 @@ class BaseDataset(Dataset):
|
||||
bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
|
||||
nb = bi[-1] + 1 # number of batches
|
||||
|
||||
s = np.array([x["shape"] for x in self.labels]) # hw
|
||||
s = np.array([x.pop("shape") for x in self.labels]) # hw
|
||||
ar = s[:, 0] / s[:, 1] # aspect ratio
|
||||
irect = ar.argsort()
|
||||
self.im_files = [self.im_files[i] for i in irect]
|
||||
|
@ -5,7 +5,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, dataloader, distributed
|
||||
|
||||
from ..utils import LOGGER
|
||||
from ..utils import LOGGER, colorstr
|
||||
from ..utils.torch_utils import torch_distributed_zero_first
|
||||
from .dataset import ClassificationDataset, YOLODataset
|
||||
from .utils import PIN_MEMORY, RANK
|
||||
@ -52,53 +52,36 @@ def seed_worker(worker_id):
|
||||
random.seed(worker_seed)
|
||||
|
||||
|
||||
# TODO: we can inject most args from a config file
|
||||
def build_dataloader(
|
||||
img_path,
|
||||
img_size, #
|
||||
batch_size, #
|
||||
single_cls=False, #
|
||||
hyp=None, #
|
||||
augment=False,
|
||||
cache=False, #
|
||||
image_weights=False, #
|
||||
stride=32,
|
||||
label_path=None,
|
||||
pad=0.0,
|
||||
rect=False,
|
||||
rank=-1,
|
||||
workers=8,
|
||||
prefix="",
|
||||
shuffle=False,
|
||||
use_segments=False,
|
||||
use_keypoints=False,
|
||||
):
|
||||
if rect and shuffle:
|
||||
def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank=-1, mode="train"):
|
||||
assert mode in ["train", "val"]
|
||||
shuffle = mode == "train"
|
||||
if cfg.rect and shuffle:
|
||||
LOGGER.warning("WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False")
|
||||
shuffle = False
|
||||
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
||||
dataset = YOLODataset(
|
||||
img_path=img_path,
|
||||
img_size=img_size,
|
||||
batch_size=batch_size,
|
||||
label_path=label_path,
|
||||
augment=augment, # augmentation
|
||||
hyp=hyp,
|
||||
rect=rect, # rectangular batches
|
||||
cache=cache,
|
||||
single_cls=single_cls,
|
||||
img_size=cfg.img_size,
|
||||
batch_size=batch_size,
|
||||
augment=True if mode == "train" else False, # augmentation
|
||||
hyp=cfg.get("augment_hyp", None),
|
||||
rect=cfg.rect if mode == "train" else True, # rectangular batches
|
||||
cache=None if cfg.noval else cfg.get("cache", None),
|
||||
single_cls=cfg.get("single_cls", False),
|
||||
stride=int(stride),
|
||||
pad=pad,
|
||||
prefix=prefix,
|
||||
use_segments=use_segments,
|
||||
use_keypoints=use_keypoints,
|
||||
pad=0.0 if mode == "train" else 0.5,
|
||||
prefix=colorstr(f"{mode}: "),
|
||||
use_segments=cfg.task == "segment",
|
||||
use_keypoints=cfg.task == "keypoint",
|
||||
)
|
||||
|
||||
batch_size = min(batch_size, len(dataset))
|
||||
nd = torch.cuda.device_count() # number of CUDA devices
|
||||
workers = cfg.workers if mode == "train" else cfg.workers * 2
|
||||
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
|
||||
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
||||
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
|
||||
loader = DataLoader if cfg.image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(6148914691236517205 + RANK)
|
||||
return (
|
||||
@ -118,6 +101,7 @@ def build_dataloader(
|
||||
|
||||
|
||||
# build classification
|
||||
# TODO: using cfg like `build_dataloader`
|
||||
def build_classification_dataloader(path,
|
||||
imgsz=224,
|
||||
batch_size=16,
|
||||
|
@ -24,11 +24,11 @@ from tqdm import tqdm
|
||||
import ultralytics.yolo.utils as utils
|
||||
import ultralytics.yolo.utils.callbacks as callbacks
|
||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||
from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT
|
||||
from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr
|
||||
from ultralytics.yolo.utils.checks import print_args
|
||||
from ultralytics.yolo.utils.files import increment_path, save_yaml
|
||||
from ultralytics.yolo.utils.modeling import get_model
|
||||
from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle
|
||||
from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
|
||||
|
||||
DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
|
||||
RANK = int(os.getenv('RANK', -1))
|
||||
@ -48,13 +48,15 @@ class BaseTrainer:
|
||||
self.wdir = self.save_dir / 'weights' # weights dir
|
||||
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
|
||||
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
|
||||
self.batch_size = self.args.batch_size
|
||||
self.epochs = self.args.epochs
|
||||
print_args(dict(self.args))
|
||||
|
||||
# Save run settings
|
||||
save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True))
|
||||
|
||||
# device
|
||||
self.device = utils.torch_utils.select_device(self.args.device, self.args.batch_size)
|
||||
self.device = utils.torch_utils.select_device(self.args.device, self.batch_size)
|
||||
self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')
|
||||
|
||||
# Model and Dataloaders.
|
||||
@ -73,10 +75,11 @@ class BaseTrainer:
|
||||
self.scheduler = None
|
||||
|
||||
# epoch level metrics
|
||||
self.metrics = {} # handle metrics returned by validator
|
||||
self.best_fitness = None
|
||||
self.fitness = None
|
||||
self.loss = None
|
||||
self.tloss = None
|
||||
self.csv = self.save_dir / 'results.csv'
|
||||
|
||||
for callback, func in callbacks.default_callbacks.items():
|
||||
self.add_callback(callback, func)
|
||||
@ -122,6 +125,7 @@ class BaseTrainer:
|
||||
if world_size > 1:
|
||||
mp.spawn(self._do_train, args=(world_size,), nprocs=world_size, join=True)
|
||||
else:
|
||||
# self._do_train(int(os.getenv("RANK", -1)), world_size)
|
||||
self._do_train()
|
||||
|
||||
def _setup_ddp(self, rank, world_size):
|
||||
@ -129,21 +133,20 @@ class BaseTrainer:
|
||||
os.environ['MASTER_PORT'] = '9020'
|
||||
torch.cuda.set_device(rank)
|
||||
self.device = torch.device('cuda', rank)
|
||||
print(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
|
||||
self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
|
||||
|
||||
dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
|
||||
self.model = self.model.to(self.device)
|
||||
self.model = DDP(self.model, device_ids=[rank])
|
||||
self.args.batch_size = self.args.batch_size // world_size
|
||||
|
||||
def _setup_train(self, rank):
|
||||
def _setup_train(self, rank, world_size):
|
||||
"""
|
||||
Builds dataloaders and optimizer on correct rank process
|
||||
"""
|
||||
# Optimizer
|
||||
self.set_model_attributes()
|
||||
accumulate = max(round(self.args.nbs / self.args.batch_size), 1) # accumulate loss before optimizing
|
||||
self.args.weight_decay *= self.args.batch_size * accumulate / self.args.nbs # scale weight_decay
|
||||
self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing
|
||||
self.args.weight_decay *= self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
|
||||
self.optimizer = build_optimizer(model=self.model,
|
||||
name=self.args.optimizer,
|
||||
lr=self.args.lr0,
|
||||
@ -151,18 +154,21 @@ class BaseTrainer:
|
||||
decay=self.args.weight_decay)
|
||||
# Scheduler
|
||||
if self.args.cos_lr:
|
||||
self.lf = one_cycle(1, self.args.lrf, self.args.epochs) # cosine 1->hyp['lrf']
|
||||
self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf']
|
||||
else:
|
||||
self.lf = lambda x: (1 - x / self.args.epochs) * (1.0 - self.args.lrf + self.args.lrf) # linear
|
||||
self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf # linear
|
||||
self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
|
||||
|
||||
# dataloaders
|
||||
self.train_loader = self.get_dataloader(self.trainset, batch_size=self.args.batch_size, rank=rank)
|
||||
batch_size = self.batch_size // world_size
|
||||
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank, mode="train")
|
||||
if rank in {0, -1}:
|
||||
print(" Creating testloader rank :", rank)
|
||||
self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=-1)
|
||||
self.validator = self.get_validator()
|
||||
print("created testloader :", rank)
|
||||
self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode="val")
|
||||
validator = self.get_validator()
|
||||
# init metric, for plot_results
|
||||
metric_keys = validator.metric_keys + self.label_loss_items(prefix="val")
|
||||
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
|
||||
self.validator = validator
|
||||
self.ema = ModelEMA(self.model)
|
||||
|
||||
def _do_train(self, rank=-1, world_size=1):
|
||||
@ -172,7 +178,7 @@ class BaseTrainer:
|
||||
self.model = self.model.to(self.device)
|
||||
|
||||
self.trigger_callbacks("before_train")
|
||||
self._setup_train(rank)
|
||||
self._setup_train(rank, world_size)
|
||||
|
||||
self.epoch = 0
|
||||
self.epoch_time = None
|
||||
@ -181,13 +187,17 @@ class BaseTrainer:
|
||||
nb = len(self.train_loader) # number of batches
|
||||
nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations
|
||||
last_opt_step = -1
|
||||
for epoch in range(self.args.epochs):
|
||||
for epoch in range(self.epochs):
|
||||
self.trigger_callbacks("on_epoch_start")
|
||||
self.model.train()
|
||||
if rank != -1:
|
||||
self.train_loader.sampler.set_epoch(epoch)
|
||||
pbar = enumerate(self.train_loader)
|
||||
if rank in {-1, 0}:
|
||||
self.console.info(self.progress_string())
|
||||
pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader), bar_format=TQDM_BAR_FORMAT)
|
||||
self.tloss = None
|
||||
self.optimizer.zero_grad()
|
||||
for i, batch in pbar:
|
||||
self.trigger_callbacks("on_batch_start")
|
||||
# forward
|
||||
@ -197,7 +207,7 @@ class BaseTrainer:
|
||||
ni = i + nb * epoch
|
||||
if ni <= nw:
|
||||
xi = [0, nw] # x interp
|
||||
accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.args.batch_size]).round())
|
||||
self.accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round())
|
||||
for j, x in enumerate(self.optimizer.param_groups):
|
||||
# bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
|
||||
x['lr'] = np.interp(
|
||||
@ -207,37 +217,47 @@ class BaseTrainer:
|
||||
|
||||
preds = self.model(batch["img"])
|
||||
self.loss, self.loss_items = self.criterion(preds, batch)
|
||||
if rank != -1:
|
||||
self.loss *= world_size
|
||||
self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
|
||||
else self.loss_items
|
||||
|
||||
# backward
|
||||
self.model.zero_grad(set_to_none=True)
|
||||
self.scaler.scale(self.loss).backward()
|
||||
|
||||
# optimize
|
||||
if ni - last_opt_step >= accumulate:
|
||||
if ni - last_opt_step >= self.accumulate:
|
||||
self.optimizer_step()
|
||||
last_opt_step = ni
|
||||
|
||||
# log
|
||||
mem = (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
|
||||
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
|
||||
loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
|
||||
losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
|
||||
if rank in {-1, 0}:
|
||||
pbar.set_description(
|
||||
(" {} " + "{:.3f} " * (1 + loss_len) + ' {} ').format(f'{epoch + 1}/{self.args.epochs}', mem,
|
||||
*losses, batch["img"].shape[-1]))
|
||||
('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
|
||||
(f'{epoch + 1}/{self.epochs}', mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1]))
|
||||
self.trigger_callbacks('on_batch_end')
|
||||
if self.args.plots and ni < 3:
|
||||
self.plot_training_samples(batch, ni)
|
||||
|
||||
lr = {f"lr{ir}": x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
|
||||
self.scheduler.step()
|
||||
|
||||
if rank in [-1, 0]:
|
||||
# validation
|
||||
self.trigger_callbacks('on_val_start')
|
||||
self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
|
||||
self.metrics, self.fitness = self.validate()
|
||||
final_epoch = (epoch + 1 == self.epochs)
|
||||
if not self.args.noval or final_epoch:
|
||||
self.metrics, self.fitness = self.validate()
|
||||
self.trigger_callbacks('on_val_end')
|
||||
log_vals = self.label_loss_items(self.tloss) | self.metrics | lr
|
||||
self.save_metrics(metrics=log_vals)
|
||||
|
||||
# save model
|
||||
if (not self.args.nosave) or (self.epoch + 1 == self.args.epochs):
|
||||
if (not self.args.nosave) or (self.epoch + 1 == self.epochs):
|
||||
self.save_model()
|
||||
self.trigger_callbacks('on_model_save')
|
||||
|
||||
@ -248,9 +268,15 @@ class BaseTrainer:
|
||||
|
||||
# TODO: termination condition
|
||||
|
||||
self.log(f"\nTraining complete ({(time.time() - self.train_time_start) / 3600:.3f} hours)")
|
||||
self.trigger_callbacks('on_train_end')
|
||||
if rank in [-1, 0]:
|
||||
# do the last evaluation with best.pt
|
||||
self.final_eval()
|
||||
if self.args.plots:
|
||||
self.plot_metrics()
|
||||
self.log(f"\nTraining complete ({(time.time() - self.train_time_start) / 3600:.3f} hours)")
|
||||
self.trigger_callbacks('on_train_end')
|
||||
dist.destroy_process_group() if world_size != 1 else None
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def save_model(self):
|
||||
ckpt = {
|
||||
@ -306,7 +332,7 @@ class BaseTrainer:
|
||||
"fitness" metric.
|
||||
"""
|
||||
metrics = self.validator(self)
|
||||
fitness = metrics.get("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
|
||||
fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
|
||||
if not self.best_fitness or self.best_fitness < fitness:
|
||||
self.best_fitness = self.fitness
|
||||
return metrics, fitness
|
||||
@ -339,12 +365,12 @@ class BaseTrainer:
|
||||
"""
|
||||
raise NotImplementedError("criterion function not implemented in trainer")
|
||||
|
||||
def label_loss_items(self, loss_items):
|
||||
def label_loss_items(self, loss_items=None, prefix="train"):
|
||||
"""
|
||||
Returns a loss dict with labelled training loss items tensor
|
||||
"""
|
||||
# Not needed for classification but necessary for segmentation & detection
|
||||
return {"loss": loss_items}
|
||||
return {"loss": loss_items} if loss_items is not None else ["loss"]
|
||||
|
||||
def set_model_attributes(self):
|
||||
"""
|
||||
@ -355,6 +381,31 @@ class BaseTrainer:
|
||||
def build_targets(self, preds, targets):
|
||||
pass
|
||||
|
||||
def progress_string(self):
    """Return the header line logged before each epoch's progress bar.

    The training loop passes this string to ``self.console.info`` at the
    start of every epoch; the base implementation contributes nothing.
    """
    return ''
|
||||
|
||||
# TODO: may need to put these following functions into callback
|
||||
def plot_training_samples(self, batch, ni):
    """Hook for plotting a training batch; the base implementation is a no-op.

    The training loop calls this with the current batch and the global
    iteration index ``ni`` while ``self.args.plots`` is set and ``ni < 3``.
    """
    return None
|
||||
|
||||
def save_metrics(self, metrics):
    """Append one row of metric values to the CSV file at ``self.csv``.

    A header row (``epoch`` followed by the metric names) is written first
    when the file does not exist yet. Each row starts with ``self.epoch``
    followed by the metric values in the dict's iteration order.

    Args:
        metrics: mapping of column name -> numeric value (values are
            formatted with ``%23.5g``).
    """
    column_names = ['epoch'] + list(metrics.keys())
    row_values = [self.epoch] + list(metrics.values())
    width = len(metrics) + 1  # number of cols, including the epoch column

    if self.csv.exists():
        header = ''
    else:
        header = ('%23s,' * width % tuple(column_names)).rstrip(',') + '\n'  # header

    with open(self.csv, 'a') as f:
        # Row formatting stays inside the `with` so a formatting error occurs after the file is opened.
        f.write(header + ('%23.5g,' * width % tuple(row_values)).rstrip(',') + '\n')
|
||||
|
||||
def plot_metrics(self):
    """Hook for plotting training metrics; the base implementation is a no-op.

    Called once after the training loop on rank -1/0 when
    ``self.args.plots`` is set.
    """
    return None
|
||||
|
||||
def final_eval(self):
    """Post-process the saved checkpoints once training has finished.

    For each of ``self.last`` and ``self.best`` that exists on disk,
    ``strip_optimizer`` is applied to the file. For the best checkpoint a
    "Validating ..." message is logged; no validation is actually run here.
    """
    # TODO: need standalone evaluator to do this
    for ckpt in (self.last, self.best):
        if not ckpt.exists():
            continue
        strip_optimizer(ckpt)  # strip optimizers
        if ckpt is self.best:
            self.console.info(f'\nValidating {ckpt}...')
|
||||
|
||||
|
||||
def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
|
||||
# TODO: 1. docstring with example? 2. Move this inside Trainer? or utils?
|
||||
@ -382,7 +433,7 @@ def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
|
||||
|
||||
optimizer.add_param_group({'params': g[0], 'weight_decay': decay}) # add g0 with weight_decay
|
||||
optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights)
|
||||
LOGGER.info(f"optimizer: {type(optimizer).__name__}(lr={lr}) with parameter groups "
|
||||
LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
|
||||
f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias")
|
||||
return optimizer
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
@ -6,6 +7,7 @@ from tqdm import tqdm
|
||||
|
||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
||||
from ultralytics.yolo.utils import TQDM_BAR_FORMAT
|
||||
from ultralytics.yolo.utils.files import increment_path
|
||||
from ultralytics.yolo.utils.ops import Profile
|
||||
from ultralytics.yolo.utils.torch_utils import de_parallel, select_device
|
||||
|
||||
@ -15,16 +17,17 @@ class BaseValidator:
|
||||
Base validator class.
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader, pbar=None, logger=None, args=None):
|
||||
def __init__(self, dataloader, save_dir=None, pbar=None, logger=None, args=None):
|
||||
self.dataloader = dataloader
|
||||
self.pbar = pbar
|
||||
self.logger = logger or logging.getLogger()
|
||||
self.args = args or OmegaConf.load(DEFAULT_CONFIG)
|
||||
self.device = select_device(self.args.device, dataloader.batch_size)
|
||||
self.save_dir = save_dir if save_dir is not None else \
|
||||
increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
|
||||
self.cuda = self.device.type != 'cpu'
|
||||
self.batch_i = None
|
||||
self.training = True
|
||||
self.loss = None
|
||||
|
||||
def __call__(self, trainer=None, model=None):
|
||||
"""
|
||||
@ -35,20 +38,22 @@ class BaseValidator:
|
||||
if self.training:
|
||||
model = trainer.ema.ema or trainer.model
|
||||
self.args.half &= self.device.type != 'cpu'
|
||||
# NOTE: half() inference in evaluation will make training stuck,
|
||||
# so I comment it out for now, I think we can reuse half mode after we add EMA.
|
||||
model = model.half() if self.args.half else model.float()
|
||||
loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
|
||||
else: # TODO: handle this when detectMultiBackend is supported
|
||||
assert model is not None, "Either trainer or model is needed for validation"
|
||||
# model = DetectMultiBacked(model)
|
||||
# TODO: implement init_model_attributes()
|
||||
|
||||
model.eval()
|
||||
|
||||
dt = Profile(), Profile(), Profile(), Profile()
|
||||
self.loss = 0
|
||||
n_batches = len(self.dataloader)
|
||||
desc = self.get_desc()
|
||||
bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT)
|
||||
# NOTE: keeping `not self.training` in tqdm removes the pbar once segmentation evaluation finishes during training,
|
||||
# so it was removed; not sure whether this affects the classification task, since yolov5/classify/val.py uses this arg.
|
||||
# bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT)
|
||||
bar = tqdm(self.dataloader, desc, n_batches, bar_format=TQDM_BAR_FORMAT)
|
||||
self.init_metrics(de_parallel(model))
|
||||
with torch.no_grad():
|
||||
for batch_i, batch in enumerate(bar):
|
||||
@ -59,20 +64,23 @@ class BaseValidator:
|
||||
|
||||
# inference
|
||||
with dt[1]:
|
||||
preds = model(batch["img"].float())
|
||||
preds = model(batch["img"])
|
||||
# TODO: remember to add native augmentation support when implementing model, like:
|
||||
# preds, train_out = model(im, augment=augment)
|
||||
|
||||
# loss
|
||||
with dt[2]:
|
||||
if self.training:
|
||||
self.loss += trainer.criterion(preds, batch)[0]
|
||||
loss += trainer.criterion(preds, batch)[1]
|
||||
|
||||
# post-process predictions
|
||||
with dt[3]:
|
||||
preds = self.postprocess(preds)
|
||||
|
||||
self.update_metrics(preds, batch)
|
||||
if self.args.plots and batch_i < 3:
|
||||
self.plot_val_samples(batch, batch_i)
|
||||
self.plot_predictions(batch, preds, batch_i)
|
||||
|
||||
stats = self.get_stats()
|
||||
self.check_stats(stats)
|
||||
@ -81,7 +89,7 @@ class BaseValidator:
|
||||
|
||||
# print speeds
|
||||
if not self.training:
|
||||
t = tuple(x.t / len(self.dataloader.dataset.samples) * 1E3 for x in dt) # speeds per image
|
||||
t = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt) # speeds per image
|
||||
# shape = (self.dataloader.batch_size, 3, imgsz, imgsz)
|
||||
self.logger.info(
|
||||
'Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image at shape ' % t)
|
||||
@ -90,7 +98,8 @@ class BaseValidator:
|
||||
model.float()
|
||||
# TODO: implement save json
|
||||
|
||||
return stats
|
||||
return stats | trainer.label_loss_items(loss.cpu() / len(self.dataloader), prefix="val") \
|
||||
if self.training else stats
|
||||
|
||||
def preprocess(self, batch):
|
||||
return batch
|
||||
@ -105,7 +114,7 @@ class BaseValidator:
|
||||
pass
|
||||
|
||||
def get_stats(self):
|
||||
pass
|
||||
return {}
|
||||
|
||||
def check_stats(self, stats):
|
||||
pass
|
||||
@ -115,3 +124,14 @@ class BaseValidator:
|
||||
|
||||
def get_desc(self):
    """Return the description for the validation progress bar.

    The result is handed to ``tqdm`` as its ``desc`` argument; the base
    implementation returns ``None``.
    """
    return None
|
||||
|
||||
@property
def metric_keys(self):
    """Names of the metrics this validator reports; empty in the base class.

    A new list is returned on every access.
    """
    return list()
|
||||
|
||||
# TODO: may need to put these following functions into callback
|
||||
def plot_val_samples(self, batch, ni):
    """Hook for plotting a validation batch; the base implementation is a no-op.

    Called with the batch and its index for the first three batches when
    ``self.args.plots`` is set.
    """
    return None
|
||||
|
||||
def plot_predictions(self, batch, preds, ni):
    """Hook for plotting predictions on a validation batch; no-op in the base class.

    Called with the batch, the post-processed predictions and the batch
    index for the first three batches when ``self.args.plots`` is set.
    """
    return None
|
||||
|
@ -3,6 +3,7 @@ import logging.config
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
# Constants
|
||||
@ -130,3 +131,13 @@ class TryExcept(contextlib.ContextDecorator):
|
||||
if value:
|
||||
print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}"))
|
||||
return True
|
||||
|
||||
|
||||
def threaded(func):
    """Multi-thread a target function and return the thread. Usage: @threaded decorator.

    Each call to the decorated function starts ``func`` in a new daemon
    ``threading.Thread`` with the given arguments and immediately returns the
    started thread. The return value of ``func`` is not available to the
    caller; use ``thread.join()`` to wait for completion.

    Args:
        func: callable to execute in a background thread.

    Returns:
        A wrapper with the same name/docstring as ``func`` that returns the
        started ``threading.Thread``.
    """
    import functools  # local import so the block does not depend on a new file-level import

    @functools.wraps(func)  # keep __name__/__doc__ of the decorated function for logging and introspection
    def wrapper(*args, **kwargs):
        thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
        thread.start()
        return thread

    return wrapper
|
||||
|
@ -26,11 +26,11 @@ deterministic: True
|
||||
local_rank: -1
|
||||
single_cls: False # train multi-class data as single-class
|
||||
image_weights: False # use weighted image selection for training
|
||||
shuffle: True
|
||||
rect: False # support rectangular training
|
||||
cos_lr: False # Use cosine LR scheduler
|
||||
overlap_mask: True # Segmentation masks overlap
|
||||
mask_ratio: 4 # Segmentation mask downsample ratio
|
||||
noval: False
|
||||
|
||||
# Val/Test settings ----------------------------------------------------------------------------------------------------
|
||||
save_json: False
|
||||
@ -43,7 +43,7 @@ plots: False
|
||||
save_txt: False
|
||||
|
||||
# Hyperparameters ------------------------------------------------------------------------------------------------------
|
||||
lr0: 0.001 # initial learning rate (SGD=1E-2, Adam=1E-3)
|
||||
lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3)
|
||||
lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
|
||||
momentum: 0.937 # SGD momentum/Adam beta1
|
||||
weight_decay: 0.0005 # optimizer weight decay 5e-4
|
||||
@ -59,22 +59,23 @@ iou_t: 0.20 # IoU training threshold
|
||||
anchor_t: 4.0 # anchor-multiple threshold
|
||||
# anchors: 3 # anchors per output layer (0 to ignore)
|
||||
fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5)
|
||||
hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
|
||||
hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
|
||||
hsv_v: 0.4 # image HSV-Value augmentation (fraction)
|
||||
degrees: 0.0 # image rotation (+/- deg)
|
||||
translate: 0.1 # image translation (+/- fraction)
|
||||
scale: 0.5 # image scale (+/- gain)
|
||||
shear: 0.0 # image shear (+/- deg)
|
||||
perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
|
||||
flipud: 0.0 # image flip up-down (probability)
|
||||
fliplr: 0.5 # image flip left-right (probability)
|
||||
mosaic: 1.0 # image mosaic (probability)
|
||||
mixup: 0.0 # image mixup (probability)
|
||||
copy_paste: 0.0 # segment copy-paste (probability)
|
||||
label_smoothing: 0.0
|
||||
nbs: 64 # nominal batch size
|
||||
# anchors: 3
|
||||
augment_hyp:
|
||||
hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
|
||||
hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
|
||||
hsv_v: 0.4 # image HSV-Value augmentation (fraction)
|
||||
degrees: 0.0 # image rotation (+/- deg)
|
||||
translate: 0.1 # image translation (+/- fraction)
|
||||
scale: 0.5 # image scale (+/- gain)
|
||||
shear: 0.0 # image shear (+/- deg)
|
||||
perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
|
||||
flipud: 0.0 # image flip up-down (probability)
|
||||
fliplr: 0.5 # image flip left-right (probability)
|
||||
mosaic: 1.0 # image mosaic (probability)
|
||||
mixup: 0.0 # image mixup (probability)
|
||||
copy_paste: 0.0 # segment copy-paste (probability)
|
||||
|
||||
# Hydra configs --------------------------------------------------------------------------------------------------------
|
||||
hydra:
|
||||
|
@ -283,6 +283,50 @@ def smooth(y, f=0.05):
|
||||
return np.convolve(yp, np.ones(nf) / nf, mode='valid') # y-smoothed
|
||||
|
||||
|
||||
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
    """Draw precision-recall curves and save the figure.

    Args:
        px: x-axis values (recall) shared by all curves.
        py: sequence of per-class precision curves; stacked along axis 1
            before plotting.
        ap: array indexed as ``ap[i, 0]``; column 0 is used for the legend
            and its mean is labelled mAP@0.5.
        save_dir: output image path passed to ``fig.savefig``.
        names: class names indexed by class position; a per-class legend is
            drawn only when there are between 1 and 20 of them.
    """
    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
    py = np.stack(py, axis=1)

    per_class_legend = 0 < len(names) < 21  # display per-class legend if < 21 classes
    if per_class_legend:
        for idx, curve in enumerate(py.T):
            ax.plot(px, curve, linewidth=1, label=f'{names[idx]} {ap[idx, 0]:.3f}')  # plot(recall, precision)
    else:
        ax.plot(px, py, linewidth=1, color='grey')  # plot(recall, precision)

    mean_label = 'all classes %.3f mAP@0.5' % ap[:, 0].mean()
    ax.plot(px, py.mean(1), linewidth=3, color='blue', label=mean_label)

    # Axis cosmetics
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    ax.set_title('Precision-Recall Curve')

    fig.savefig(save_dir, dpi=250)
    plt.close(fig)
|
||||
|
||||
|
||||
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
    """Draw metric-vs-confidence curves and save the figure.

    Args:
        px: x-axis values (confidence) shared by all curves.
        py: array of curves, one row per class (iterated row-wise and
            averaged over axis 0).
        save_dir: output image path passed to ``fig.savefig``.
        names: class names indexed by class position; a per-class legend is
            drawn only when there are between 1 and 20 of them.
        xlabel: x-axis label.
        ylabel: y-axis label, also used in the plot title.
    """
    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)

    if not 0 < len(names) < 21:
        ax.plot(px, py.T, linewidth=1, color='grey')  # plot(confidence, metric)
    else:  # display per-class legend if < 21 classes
        for idx, curve in enumerate(py):
            ax.plot(px, curve, linewidth=1, label=f'{names[idx]}')  # plot(confidence, metric)

    # Smoothed mean curve across classes, annotated with its peak value and location
    mean_curve = smooth(py.mean(0), 0.05)
    ax.plot(px,
            mean_curve,
            linewidth=3,
            color='blue',
            label=f'all classes {mean_curve.max():.2f} at {px[mean_curve.argmax()]:.3f}')

    # Axis cosmetics
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    ax.set_title(f'{ylabel}-Confidence Curve')

    fig.savefig(save_dir, dpi=250)
    plt.close(fig)
|
||||
|
||||
|
||||
def compute_ap(recall, precision):
|
||||
""" Compute the average precision, given the recall and precision curves
|
||||
# Arguments
|
||||
@ -365,14 +409,11 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
|
||||
f1 = 2 * p * r / (p + r + eps)
|
||||
names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
|
||||
names = dict(enumerate(names)) # to dict
|
||||
# TODO: plot
|
||||
'''
|
||||
if plot:
|
||||
plot_pr_curve(px, py, ap, Path(save_dir) / f'{prefix}PR_curve.png', names)
|
||||
plot_mc_curve(px, f1, Path(save_dir) / f'{prefix}F1_curve.png', names, ylabel='F1')
|
||||
plot_mc_curve(px, p, Path(save_dir) / f'{prefix}P_curve.png', names, ylabel='Precision')
|
||||
plot_mc_curve(px, r, Path(save_dir) / f'{prefix}R_curve.png', names, ylabel='Recall')
|
||||
'''
|
||||
|
||||
i = smooth(f1.mean(0), 0.1).argmax() # max F1 index
|
||||
p, r, f1 = p[:, i], r[:, i], f1[:, i]
|
||||
|
@ -1,12 +1,16 @@
|
||||
import contextlib
|
||||
import math
|
||||
from pathlib import Path
|
||||
from urllib.error import URLError
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
from ultralytics.yolo.utils import FONT, USER_CONFIG_DIR
|
||||
from ultralytics.yolo.utils import FONT, USER_CONFIG_DIR, threaded
|
||||
|
||||
from .checks import check_font, check_requirements, is_ascii
|
||||
from .files import increment_path
|
||||
@ -179,3 +183,147 @@ def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False,
|
||||
# cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
|
||||
Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB
|
||||
return crop
|
||||
|
||||
|
||||
@threaded
def plot_images_and_masks(images, batch_idx, cls, bboxes, masks, paths, confs=None, fname='images.jpg', names=None):
    """Draw a batch as an image mosaic with boxes, labels and instance masks and save it to ``fname``.

    Args:
        images: Batch indexed as (batch, channel, height, width); torch.Tensor or numpy array.
        batch_idx: For every box, the index of the image it belongs to.
        cls: Class id of every box.
        bboxes: One xywh box per row (converted with ``xywh2xyxy``). Values up to 1.01 are treated as
            normalised coordinates, larger values as absolute pixels.
        masks: Instance masks. Either one binary mask per box, or (when any value is greater than 1)
            one index-encoded "overlap" mask per image in which instance ``k`` is stored as ``k + 1``.
        paths: Image file paths used as tile captions; falsy to skip captions.
        confs: Confidence of every box for predictions; ``None`` means the boxes are labels.
        fname: Output file path.
        names: Optional class-id -> name lookup used for the box captions.
    """
    # Convert every tensor input to numpy so the drawing code only deals with arrays
    if isinstance(images, torch.Tensor):
        images = images.cpu().float().numpy()
    if isinstance(cls, torch.Tensor):
        cls = cls.cpu().numpy()
    if isinstance(bboxes, torch.Tensor):
        bboxes = bboxes.cpu().numpy()
    if isinstance(masks, torch.Tensor):
        masks = masks.cpu().numpy().astype(int)
    if isinstance(batch_idx, torch.Tensor):
        batch_idx = batch_idx.cpu().numpy()

    max_size = 1920  # max image size
    max_subplots = 16  # max image subplots, i.e. 4x4
    bs, _, h, w = images.shape  # batch size, _, height, width
    bs = min(bs, max_subplots)  # limit plot images
    ns = np.ceil(bs ** 0.5)  # number of subplots (square)
    if np.max(images[0]) <= 1:
        # De-normalise on a copy: for a CPU float tensor the array above shares memory with the
        # caller's tensor, so an in-place multiply would silently rescale the caller's batch.
        images = images[:bs] * 255

    # Build Image
    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
    for i, im in enumerate(images[:bs]):
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        mosaic[y:y + h, x:x + w, :] = im.transpose(1, 2, 0)

    # Resize (optional)
    scale = max_size / ns / max(h, w)
    if scale < 1:
        h = math.ceil(scale * h)
        w = math.ceil(scale * w)
        mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))

    # Annotate
    fs = int((h + w) * ns * 0.01)  # font size
    annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
    for i in range(bs):  # exactly the tiles that were pasted into the mosaic
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2)  # borders
        if paths:
            annotator.text((x + 5, y + 5 + h), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220))  # filenames
        if len(cls) > 0:
            idx = batch_idx == i  # boxes that belong to tile i

            boxes = xywh2xyxy(bboxes[idx]).T
            classes = cls[idx].astype('int')
            labels = confs is None  # labels if no conf column
            conf = None if labels else confs[idx]  # check for confidence presence (label vs pred)

            if boxes.shape[1]:
                if boxes.max() <= 1.01:  # if normalized with tolerance 0.01
                    boxes[[0, 2]] *= w  # scale to pixels
                    boxes[[1, 3]] *= h
                elif scale < 1:  # absolute coords need scale if image scales
                    boxes *= scale
            boxes[[0, 2]] += x
            boxes[[1, 3]] += y
            for j, box in enumerate(boxes.T.tolist()):
                c = classes[j]
                color = colors(c)
                c = names[c] if names else c
                if labels or conf[j] > 0.25:  # 0.25 conf thresh
                    label = f'{c}' if labels else f'{c} {conf[j]:.1f}'
                    annotator.box_label(box, label, color=color)

            # Plot masks
            if len(masks):
                if masks.max() > 1.0:  # index-encoded overlap mask: one mask per image
                    image_masks = masks[[i]]  # (1, H, W)
                    nl = idx.sum()
                    index = np.arange(nl).reshape(nl, 1, 1) + 1
                    image_masks = np.repeat(image_masks, nl, axis=0)
                    image_masks = np.where(image_masks == index, 1.0, 0.0)  # split into nl binary masks
                else:
                    image_masks = masks[idx]

                im = np.asarray(annotator.im).copy()
                for j, box in enumerate(boxes.T.tolist()):
                    if labels or conf[j] > 0.25:  # 0.25 conf thresh
                        color = colors(classes[j])
                        mh, mw = image_masks[j].shape
                        if mh != h or mw != w:
                            mask = image_masks[j].astype(np.uint8)
                            mask = cv2.resize(mask, (w, h))
                            mask = mask.astype(bool)
                        else:
                            mask = image_masks[j].astype(bool)
                        # Best effort: a mask that cannot be blended into its tile is skipped
                        # instead of aborting the whole plot.
                        with contextlib.suppress(Exception):
                            im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6
                annotator.fromarray(im)
    annotator.im.save(fname)  # save
|
||||
|
||||
|
||||
def plot_results_with_masks(file="path/to/results.csv", dir="", best=True):
    """Plot training curves from every ``results*.csv`` in a run directory and save ``results.png``.

    Usage: from utils.plots import *; plot_results('path/to/results.csv')

    Args:
        file: Path of a results CSV; its parent directory is searched for ``results*.csv``.
        dir: Directory to search instead when ``file`` is falsy.
        best: If True, mark and title each subplot with the row that maximises
            ``0.9 * col8 + 0.1 * col7 + 0.9 * col12 + 0.1 * col11``; otherwise mark the last row.
            NOTE(review): these columns are presumably the box/mask mAP metrics - confirm against
            the CSV writer.

    Raises:
        AssertionError: If no matching CSV file is found.
    """
    save_dir = Path(file).parent if file else Path(dir)
    files = list(save_dir.glob("results*.csv"))
    # Check before creating the figure so a failed lookup does not leave an open figure behind
    assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."

    fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
    ax = ax.ravel()
    plot_columns = [1, 2, 3, 4, 5, 6, 9, 10, 13, 14, 15, 16, 7, 8, 11, 12]  # CSV column shown in each subplot
    for f in files:
        try:
            data = pd.read_csv(f)
            values = data.values
            index = np.argmax(0.9 * values[:, 8] + 0.1 * values[:, 7] + 0.9 * values[:, 12] + 0.1 * values[:, 11])
            s = [c.strip() for c in data.columns]
            x = values[:, 0]
            for i, j in enumerate(plot_columns):
                y = values[:, j]
                ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=2)
                if best:
                    # best: place the marker at the x value of the best row, not at its row number
                    ax[i].scatter(x[index], y[index], color="r", label=f"best:{index}", marker="*", linewidth=3)
                    ax[i].set_title(s[j] + f"\n{round(y[index], 5)}")
                else:
                    # last
                    ax[i].scatter(x[-1], y[-1], color="r", label="last", marker="*", linewidth=3)
                    ax[i].set_title(s[j] + f"\n{round(y[-1], 5)}")
        except Exception as e:
            # Deliberately best effort: one unreadable CSV must not prevent plotting the others
            print(f"Warning: Plotting error for {f}: {e}")
    ax[1].legend()
    fig.savefig(save_dir / "results.png", dpi=200)
    plt.close(fig)  # close this figure explicitly rather than whichever one is current
|
||||
|
||||
|
||||
def output_to_target(output, max_det=300):
    """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting.

    Args:
        output: Iterable with one 2-D detection tensor per image. The first six columns of each row
            are read as four box values (converted with ``xyxy2xywh``), confidence and class.
        max_det: Maximum number of rows kept per image.

    Returns:
        Tuple ``(batch_id, class_id, xywh, conf)`` of numpy arrays with one entry per kept detection.
        All four are empty when ``output`` contains no images.
    """
    targets = []
    for i, o in enumerate(output):
        box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
        j = torch.full((conf.shape[0], 1), i)  # image index column
        targets.append(torch.cat((j, cls, xyxy2xywh(box), conf), 1))
    if targets:
        targets = torch.cat(targets, 0).numpy()
    else:
        # torch.cat raises on an empty list; return correctly shaped empty arrays instead
        targets = torch.zeros((0, 7)).numpy()
    return targets[:, 0], targets[:, 1], targets[:, 2:6], targets[:, 6]
|
||||
|
@ -245,3 +245,19 @@ class ModelEMA:
|
||||
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
    """Update EMA attributes by copying them from ``model`` onto ``self.ema`` via ``copy_attr``.

    Args:
        model: Source object whose attributes are copied.
        include: Forwarded to ``copy_attr`` as its include filter.
        exclude: Forwarded to ``copy_attr`` as its exclude filter.
    """
    copy_attr(self.ema, model, include, exclude)
|
||||
|
||||
|
||||
def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    """Strip optimizer from 'f' to finalize training, optionally save as 's'.

    The checkpoint's model is replaced by its EMA copy when one is stored, the training-only entries
    are cleared, the model is converted to FP16 with gradients disabled, and the result is saved.

    Args:
        f: Path of the checkpoint to load.
        s: Optional output path; when empty the checkpoint at ``f`` is overwritten.
    """
    # NOTE(security): torch.load unpickles the file - only use on checkpoints from a trusted source.
    x = torch.load(f, map_location=torch.device('cpu'))
    if x.get('ema') is not None:  # explicit None check instead of truth-testing a module
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'best_fitness', 'ema', 'updates':  # keys
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    out = s or f
    torch.save(x, out)
    mb = os.path.getsize(out) / 1E6  # filesize
    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
||||
|
@ -9,6 +9,9 @@ from ultralytics.yolo.utils.modeling.tasks import ClassificationModel
|
||||
|
||||
class ClassificationTrainer(BaseTrainer):
|
||||
|
||||
def set_model_attributes(self):
    """Copy the class names from ``self.data["names"]`` onto ``self.model.names``."""
    self.model.names = self.data["names"]
|
||||
|
||||
def load_model(self, model_cfg, weights, data):
|
||||
# TODO: why treat clf models as unique. We should have clf yamls?
|
||||
if weights and not weights.__class__.__name__.startswith("yolo"): # torchvision
|
||||
@ -18,7 +21,7 @@ class ClassificationTrainer(BaseTrainer):
|
||||
ClassificationModel.reshape_outputs(model, data["nc"])
|
||||
return model
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size=None, rank=0):
|
||||
def get_dataloader(self, dataset_path, batch_size, rank=0, mode="train"):
|
||||
return build_classification_dataloader(path=dataset_path,
|
||||
imgsz=self.args.img_size,
|
||||
batch_size=batch_size,
|
||||
|
@ -23,3 +23,7 @@ class ClassificationValidator(BaseValidator):
|
||||
acc = torch.stack((self.correct[:, 0], self.correct.max(1).values), dim=1) # (top1, top5) accuracy
|
||||
top1, top5 = acc.mean(0).tolist()
|
||||
return {"top1": top1, "top5": top5, "fitness": top5}
|
||||
|
||||
@property
def metric_keys(self):
    """Names of the reported classification metrics, in reporting order."""
    return ["top1", "top5"]
|
||||
|
@ -9,30 +9,18 @@ from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
|
||||
from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE
|
||||
from ultralytics.yolo.utils.modeling.tasks import SegmentationModel
|
||||
from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy
|
||||
from ultralytics.yolo.utils.plotting import plot_images_and_masks, plot_results_with_masks
|
||||
from ultralytics.yolo.utils.torch_utils import de_parallel
|
||||
|
||||
|
||||
# BaseTrainer python usage
|
||||
class SegmentationTrainer(BaseTrainer):
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size, rank=0):
|
||||
def get_dataloader(self, dataset_path, batch_size, mode="train", rank=0):
|
||||
# TODO: manage splits differently
|
||||
# calculate stride - check if model is initialized
|
||||
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
|
||||
return build_dataloader(
|
||||
img_path=dataset_path,
|
||||
img_size=self.args.img_size,
|
||||
batch_size=batch_size,
|
||||
single_cls=self.args.single_cls,
|
||||
cache=self.args.cache,
|
||||
image_weights=self.args.image_weights,
|
||||
stride=gs,
|
||||
rect=self.args.rect,
|
||||
rank=rank,
|
||||
workers=self.args.workers,
|
||||
shuffle=self.args.shuffle,
|
||||
use_segments=True,
|
||||
)[0]
|
||||
return build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, rank=rank, mode=mode)[0]
|
||||
|
||||
def preprocess_batch(self, batch):
|
||||
batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
|
||||
@ -58,7 +46,10 @@ class SegmentationTrainer(BaseTrainer):
|
||||
self.model.names = self.data["names"]
|
||||
|
||||
def get_validator(self):
|
||||
return v8.segment.SegmentationValidator(self.test_loader, self.device, logger=self.console)
|
||||
return v8.segment.SegmentationValidator(self.test_loader,
|
||||
save_dir=self.save_dir,
|
||||
logger=self.console,
|
||||
args=self.args)
|
||||
|
||||
def criterion(self, preds, batch):
|
||||
head = de_parallel(self.model).model[-1]
|
||||
@ -218,6 +209,8 @@ class SegmentationTrainer(BaseTrainer):
|
||||
else:
|
||||
mask_gti = masks[tidxs[i]][j]
|
||||
lseg += single_mask_loss(mask_gti, pmask[j], proto[bi], mxyxy[j], marea[j])
|
||||
else:
|
||||
lseg += (proto * 0).sum()
|
||||
|
||||
obji = BCEobj(pi[..., 4], tobj)
|
||||
lobj += obji * balance[i] # obj loss
|
||||
@ -234,15 +227,33 @@ class SegmentationTrainer(BaseTrainer):
|
||||
loss = lbox + lobj + lcls + lseg
|
||||
return loss * bs, torch.cat((lbox, lseg, lobj, lcls)).detach()
|
||||
|
||||
def label_loss_items(self, loss_items):
|
||||
def label_loss_items(self, loss_items=None, prefix="train"):
|
||||
# We should just use named tensors here in future
|
||||
keys = ["lbox", "lseg", "lobj", "lcls"]
|
||||
return dict(zip(keys, loss_items))
|
||||
keys = [f"{prefix}/lbox", f"{prefix}/lseg", f"{prefix}/lobj", f"{prefix}/lcls"]
|
||||
return dict(zip(keys, loss_items)) if loss_items is not None else keys
|
||||
|
||||
def progress_string(self):
    """Return the progress-table header: a newline followed by seven right-aligned 11-character columns."""
    columns = ('Epoch', 'GPU_mem', 'box_loss', 'seg_loss', 'obj_loss', 'cls_loss', 'Size')
    return ('\n' + '%11s' * len(columns)) % columns
|
||||
|
||||
def plot_training_samples(self, batch, ni):
    """Plot one training batch with its boxes and masks to ``train_batch{ni}.jpg`` in ``self.save_dir``.

    Args:
        batch: Mapping that provides the "img", "masks", "cls", "bboxes", "im_file" and
            "batch_idx" entries read below.
        ni: Number used in the output file name.
    """
    plot_images_and_masks(batch["img"],
                          batch["batch_idx"],
                          batch["cls"].squeeze(-1),  # drop the trailing singleton dimension
                          batch["bboxes"],
                          batch["masks"],
                          batch["im_file"],
                          fname=self.save_dir / f"train_batch{ni}.jpg")
|
||||
|
||||
def plot_metrics(self):
    """Plot the training curves recorded in ``self.csv`` (saves results.png)."""
    plot_results_with_masks(file=self.csv)  # save results.png
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
|
||||
def train(cfg):
|
||||
|
@ -6,23 +6,24 @@ import torch.nn.functional as F
|
||||
|
||||
from ultralytics.yolo.engine.validator import BaseValidator
|
||||
from ultralytics.yolo.utils import ops
|
||||
from ultralytics.yolo.utils.checks import check_requirements
|
||||
from ultralytics.yolo.utils.checks import check_file, check_requirements
|
||||
from ultralytics.yolo.utils.files import yaml_load
|
||||
from ultralytics.yolo.utils.metrics import (ConfusionMatrix, Metrics, ap_per_class_box_and_mask, box_iou,
|
||||
fitness_segmentation, mask_iou)
|
||||
from ultralytics.yolo.utils.plotting import output_to_target, plot_images_and_masks
|
||||
from ultralytics.yolo.utils.torch_utils import de_parallel
|
||||
|
||||
|
||||
class SegmentationValidator(BaseValidator):
|
||||
|
||||
def __init__(self, dataloader, pbar=None, logger=None, args=None):
|
||||
super().__init__(dataloader, pbar, logger, args)
|
||||
def __init__(self, dataloader, save_dir=None, pbar=None, logger=None, args=None):
|
||||
super().__init__(dataloader, save_dir, pbar, logger, args)
|
||||
if self.args.save_json:
|
||||
check_requirements(['pycocotools'])
|
||||
self.process = ops.process_mask_upsample # more accurate
|
||||
else:
|
||||
self.process = ops.process_mask # faster
|
||||
self.data_dict = yaml_load(self.args.data) if self.args.data else None
|
||||
self.data_dict = yaml_load(check_file(self.args.data)) if self.args.data else None
|
||||
self.is_coco = False
|
||||
self.class_map = None
|
||||
self.targets = None
|
||||
@ -62,6 +63,7 @@ class SegmentationValidator(BaseValidator):
|
||||
self.loss = torch.zeros(4, device=self.device)
|
||||
self.jdict = []
|
||||
self.stats = []
|
||||
self.plot_masks = []
|
||||
|
||||
def get_desc(self):
|
||||
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P",
|
||||
@ -80,11 +82,10 @@ class SegmentationValidator(BaseValidator):
|
||||
|
||||
def update_metrics(self, preds, batch):
|
||||
# Metrics
|
||||
plot_masks = [] # masks for plotting
|
||||
for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
|
||||
labels = self.targets[self.targets[:, 0] == si, 1:]
|
||||
nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions
|
||||
shape = batch["shape"][si]
|
||||
shape = batch["ori_shape"][si]
|
||||
# path = batch["shape"][si][0]
|
||||
correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
||||
correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device) # init
|
||||
@ -130,7 +131,7 @@ class SegmentationValidator(BaseValidator):
|
||||
|
||||
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
|
||||
if self.args.plots and self.batch_i < 3:
|
||||
plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot
|
||||
self.plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot
|
||||
|
||||
# TODO: Save/log
|
||||
'''
|
||||
@ -143,26 +144,14 @@ class SegmentationValidator(BaseValidator):
|
||||
# callbacks.run('on_val_image_end', pred, predn, path, names, im[si])
|
||||
'''
|
||||
|
||||
# TODO Plot images
|
||||
'''
|
||||
if self.args.plots and self.batch_i < 3:
|
||||
if len(plot_masks):
|
||||
plot_masks = torch.cat(plot_masks, dim=0)
|
||||
plot_images_and_masks(im, targets, masks, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names)
|
||||
plot_images_and_masks(im, output_to_target(preds, max_det=15), plot_masks, paths,
|
||||
save_dir / f'val_batch{batch_i}_pred.jpg', names) # pred
|
||||
'''
|
||||
|
||||
def get_stats(self):
|
||||
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
|
||||
if len(stats) and stats[0].any():
|
||||
# TODO: save_dir
|
||||
results = ap_per_class_box_and_mask(*stats, plot=self.args.plots, save_dir='', names=self.names)
|
||||
results = ap_per_class_box_and_mask(*stats, plot=self.args.plots, save_dir=self.save_dir, names=self.names)
|
||||
self.metrics.update(results)
|
||||
self.nt_per_class = np.bincount(stats[4].astype(int), minlength=self.nc) # number of targets per class
|
||||
keys = ["mp_bbox", "mr_bbox", "map50_bbox", "map_bbox", "mp_mask", "mr_mask", "map50_mask", "map_mask"]
|
||||
metrics = {"fitness": fitness_segmentation(np.array(self.metrics.mean_results()).reshape(1, -1))}
|
||||
metrics |= zip(keys, self.metrics.mean_results())
|
||||
metrics |= zip(self.metric_keys, self.metrics.mean_results())
|
||||
return metrics
|
||||
|
||||
def print_results(self):
|
||||
@ -177,9 +166,8 @@ class SegmentationValidator(BaseValidator):
|
||||
for i, c in enumerate(self.metrics.ap_class_index):
|
||||
self.logger.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
|
||||
|
||||
# plot TODO: save_dir
|
||||
if self.args.plots:
|
||||
self.confusion_matrix.plot(save_dir='', names=list(self.names.values()))
|
||||
self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values()))
|
||||
|
||||
def _process_batch(self, detections, labels, iouv, pred_masks=None, gt_masks=None, overlap=False, masks=False):
|
||||
"""
|
||||
@ -217,3 +205,41 @@ class SegmentationValidator(BaseValidator):
|
||||
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
|
||||
correct[matches[:, 1].astype(int), i] = True
|
||||
return torch.tensor(correct, dtype=torch.bool, device=iouv.device)
|
||||
|
||||
@property
def metric_keys(self):
    """Names of the reported segmentation metrics: four box (B) entries followed by four mask (M) entries."""
    return [
        "metrics/precision(B)",
        "metrics/recall(B)",
        "metrics/mAP_0.5(B)",
        "metrics/mAP_0.5:0.95(B)",  # box metrics
        "metrics/precision(M)",
        "metrics/recall(M)",
        "metrics/mAP_0.5(M)",
        "metrics/mAP_0.5:0.95(M)",  # mask metrics
    ]
|
||||
|
||||
def plot_val_samples(self, batch, ni):
    """Plot one validation batch with its label boxes and masks to ``val_batch{ni}_labels.jpg``.

    Args:
        batch: Mapping that provides the "img", "masks", "cls", "bboxes", "im_file" and
            "batch_idx" entries read below.
        ni: Number used in the output file name.
    """
    plot_images_and_masks(batch["img"],
                          batch["batch_idx"],
                          batch["cls"].squeeze(-1),  # drop the trailing singleton dimension
                          batch["bboxes"],
                          batch["masks"],
                          batch["im_file"],
                          fname=self.save_dir / f"val_batch{ni}_labels.jpg",
                          names=self.names)
|
||||
|
||||
def plot_predictions(self, batch, preds, ni):
    """Plot the predicted boxes and collected masks of one batch to ``val_batch{ni}_pred.jpg``.

    Args:
        batch: Mapping that provides the "img" and "im_file" entries read below.
        preds: Model predictions; ``preds[0]`` is passed to ``output_to_target``.
        ni: Number used in the output file name.
    """
    images = batch["img"]
    paths = batch["im_file"]
    if len(self.plot_masks):
        plot_masks = torch.cat(self.plot_masks, dim=0)
    else:
        # Nothing was collected for this batch: pass an empty tensor so only boxes are drawn
        # (previously `plot_masks` was left unbound here).
        plot_masks = torch.zeros(0, dtype=torch.uint8)
    batch_idx, cls, bboxes, conf = output_to_target(preds[0], max_det=15)
    plot_images_and_masks(images, batch_idx, cls, bboxes, plot_masks, paths, conf,
                          self.save_dir / f'val_batch{ni}_pred.jpg', self.names)  # pred
    self.plot_masks.clear()  # start collecting afresh for the next batch
|
||||
|
Loading…
x
Reference in New Issue
Block a user