From 3a241e4ceacdb2c8a39f32490f2b87fd29710131 Mon Sep 17 00:00:00 2001
From: Laughing <61612323+Laughing-q@users.noreply.github.com>
Date: Tue, 29 Nov 2022 05:30:08 -0600
Subject: [PATCH] update segment training (#57)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com>
---
 ultralytics/yolo/data/augment.py            |   6 +-
 ultralytics/yolo/data/base.py               |   2 +-
 ultralytics/yolo/data/build.py              |  54 +++----
 ultralytics/yolo/engine/trainer.py          | 115 ++++++++++-----
 ultralytics/yolo/engine/validator.py        |  42 ++++--
 ultralytics/yolo/utils/__init__.py          |  11 ++
 ultralytics/yolo/utils/configs/default.yaml |  31 ++--
 ultralytics/yolo/utils/metrics.py           |  47 +++++-
 ultralytics/yolo/utils/plotting.py          | 150 +++++++++++++++++++-
 ultralytics/yolo/utils/torch_utils.py       |  16 +++
 ultralytics/yolo/v8/classify/train.py       |   5 +-
 ultralytics/yolo/v8/classify/val.py         |   4 +
 ultralytics/yolo/v8/segment/train.py        |  49 ++++---
 ultralytics/yolo/v8/segment/val.py          |  72 +++++++---
 14 files changed, 460 insertions(+), 144 deletions(-)

diff --git a/ultralytics/yolo/data/augment.py b/ultralytics/yolo/data/augment.py
index c67b5c62..f30f630e 100644
--- a/ultralytics/yolo/data/augment.py
+++ b/ultralytics/yolo/data/augment.py
@@ -578,8 +578,8 @@ class Albumentations:
             # TODO: add supports of segments and keypoints
             if self.transform and random.random() < self.p:
                 new = self.transform(image=im, bboxes=bboxes, class_labels=cls)  # transformed
-            labels["img"] = new["image"]
-            labels["cls"] = np.array(new["class_labels"])
+                labels["img"] = new["image"]
+                labels["cls"] = np.array(new["class_labels"])
             labels["instances"].update(bboxes=bboxes)
         return labels
 
@@ -635,7 +635,7 @@ class Format:
     def _format_img(self, img):
         if len(img.shape) < 3:
             img = np.expand_dims(img, -1)
-        img = np.ascontiguousarray(img.transpose(2, 0, 1))
+        img = np.ascontiguousarray(img.transpose(2, 0, 1)[::-1])
         img = torch.from_numpy(img)
         return img
 
diff --git a/ultralytics/yolo/data/base.py b/ultralytics/yolo/data/base.py
index 32b9344d..9a11f6ed 100644
--- a/ultralytics/yolo/data/base.py
+++ b/ultralytics/yolo/data/base.py
@@ -151,7 +151,7 @@ class BaseDataset(Dataset):
         bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int)  # batch index
         nb = bi[-1] + 1  # number of batches
 
-        s = np.array([x["shape"] for x in self.labels])  # hw
+        s = np.array([x.pop("shape") for x in self.labels])  # hw
         ar = s[:, 0] / s[:, 1]  # aspect ratio
         irect = ar.argsort()
         self.im_files = [self.im_files[i] for i in irect]
diff --git a/ultralytics/yolo/data/build.py b/ultralytics/yolo/data/build.py
index 9bc4960c..3f3b881a 100644
--- a/ultralytics/yolo/data/build.py
+++ b/ultralytics/yolo/data/build.py
@@ -5,7 +5,7 @@ import numpy as np
 import torch
 from torch.utils.data import DataLoader, dataloader, distributed
 
-from ..utils import LOGGER
+from ..utils import LOGGER, colorstr
 from ..utils.torch_utils import torch_distributed_zero_first
 from .dataset import ClassificationDataset, YOLODataset
 from .utils import PIN_MEMORY, RANK
@@ -52,53 +52,36 @@ def seed_worker(worker_id):
     random.seed(worker_seed)
 
 
-# TODO: we can inject most args from a config file
-def build_dataloader(
-    img_path,
-    img_size,  #
-    batch_size,  #
-    single_cls=False,  #
-    hyp=None,  #
-    augment=False,
-    cache=False,  #
-    image_weights=False,  #
-    stride=32,
-    label_path=None,
-    pad=0.0,
-    rect=False,
-    rank=-1,
-    workers=8,
-    prefix="",
-    shuffle=False,
-    use_segments=False,
-    use_keypoints=False,
-):
-    if rect and shuffle:
+def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank=-1, mode="train"):
+    assert mode in ["train", "val"]
+    shuffle = mode == "train"
+    if cfg.rect and shuffle:
         LOGGER.warning("WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False")
         shuffle = False
     with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
         dataset = YOLODataset(
             img_path=img_path,
-            img_size=img_size,
-            batch_size=batch_size,
             label_path=label_path,
-            augment=augment,  # augmentation
-            hyp=hyp,
-            rect=rect,  # rectangular batches
-            cache=cache,
-            single_cls=single_cls,
+            img_size=cfg.img_size,
+            batch_size=batch_size,
+            augment=True if mode == "train" else False,  # augmentation
+            hyp=cfg.get("augment_hyp", None),
+            rect=cfg.rect if mode == "train" else True,  # rectangular batches
+            cache=None if cfg.noval else cfg.get("cache", None),
+            single_cls=cfg.get("single_cls", False),
             stride=int(stride),
-            pad=pad,
-            prefix=prefix,
-            use_segments=use_segments,
-            use_keypoints=use_keypoints,
+            pad=0.0 if mode == "train" else 0.5,
+            prefix=colorstr(f"{mode}: "),
+            use_segments=cfg.task == "segment",
+            use_keypoints=cfg.task == "keypoint",
         )
 
     batch_size = min(batch_size, len(dataset))
     nd = torch.cuda.device_count()  # number of CUDA devices
+    workers = cfg.workers if mode == "train" else cfg.workers * 2
     nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])  # number of workers
     sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
-    loader = DataLoader if image_weights else InfiniteDataLoader  # only DataLoader allows for attribute updates
+    loader = DataLoader if cfg.image_weights else InfiniteDataLoader  # only DataLoader allows for attribute updates
     generator = torch.Generator()
     generator.manual_seed(6148914691236517205 + RANK)
     return (
@@ -118,6 +101,7 @@ def build_dataloader(
 
 
 # build classification
+# TODO: using cfg like `build_dataloader`
 def build_classification_dataloader(path,
                                     imgsz=224,
                                     batch_size=16,
diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py
index 8f1f7b8b..8f319876 100644
--- a/ultralytics/yolo/engine/trainer.py
+++ b/ultralytics/yolo/engine/trainer.py
@@ -24,11 +24,11 @@ from tqdm import tqdm
 import ultralytics.yolo.utils as utils
 import ultralytics.yolo.utils.callbacks as callbacks
 from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
-from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT
+from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr
 from ultralytics.yolo.utils.checks import print_args
 from ultralytics.yolo.utils.files import increment_path, save_yaml
 from ultralytics.yolo.utils.modeling import get_model
-from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle
+from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
 
 DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
 RANK = int(os.getenv('RANK', -1))
@@ -48,13 +48,15 @@ class BaseTrainer:
         self.wdir = self.save_dir / 'weights'  # weights dir
         self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
         self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'  # checkpoint paths
+        self.batch_size = self.args.batch_size
+        self.epochs = self.args.epochs
         print_args(dict(self.args))
 
         # Save run settings
         save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True))
 
         # device
-        self.device = utils.torch_utils.select_device(self.args.device, self.args.batch_size)
+        self.device = utils.torch_utils.select_device(self.args.device, self.batch_size)
         self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')
 
         # Model and Dataloaders.
@@ -73,10 +75,11 @@ class BaseTrainer:
         self.scheduler = None
 
         # epoch level metrics
-        self.metrics = {}  # handle metrics returned by validator
         self.best_fitness = None
         self.fitness = None
         self.loss = None
+        self.tloss = None
+        self.csv = self.save_dir / 'results.csv'
 
         for callback, func in callbacks.default_callbacks.items():
             self.add_callback(callback, func)
@@ -122,6 +125,7 @@ class BaseTrainer:
         if world_size > 1:
             mp.spawn(self._do_train, args=(world_size,), nprocs=world_size, join=True)
         else:
+            # self._do_train(int(os.getenv("RANK", -1)), world_size)
             self._do_train()
 
     def _setup_ddp(self, rank, world_size):
@@ -129,21 +133,20 @@ class BaseTrainer:
         os.environ['MASTER_PORT'] = '9020'
         torch.cuda.set_device(rank)
         self.device = torch.device('cuda', rank)
-        print(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
+        self.console.info(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
 
         dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
         self.model = self.model.to(self.device)
         self.model = DDP(self.model, device_ids=[rank])
-        self.args.batch_size = self.args.batch_size // world_size
 
-    def _setup_train(self, rank):
+    def _setup_train(self, rank, world_size):
         """
         Builds dataloaders and optimizer on correct rank process
         """
         # Optimizer
         self.set_model_attributes()
-        accumulate = max(round(self.args.nbs / self.args.batch_size), 1)  # accumulate loss before optimizing
-        self.args.weight_decay *= self.args.batch_size * accumulate / self.args.nbs  # scale weight_decay
+        self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
+        self.args.weight_decay *= self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
         self.optimizer = build_optimizer(model=self.model,
                                          name=self.args.optimizer,
                                          lr=self.args.lr0,
@@ -151,18 +154,21 @@ class BaseTrainer:
                                          decay=self.args.weight_decay)
         # Scheduler
         if self.args.cos_lr:
-            self.lf = one_cycle(1, self.args.lrf, self.args.epochs)  # cosine 1->hyp['lrf']
+            self.lf = one_cycle(1, self.args.lrf, self.epochs)  # cosine 1->hyp['lrf']
         else:
-            self.lf = lambda x: (1 - x / self.args.epochs) * (1.0 - self.args.lrf + self.args.lrf)  # linear
+            self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear
         self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
 
         # dataloaders
-        self.train_loader = self.get_dataloader(self.trainset, batch_size=self.args.batch_size, rank=rank)
+        batch_size = self.batch_size // world_size
+        self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank, mode="train")
         if rank in {0, -1}:
-            print(" Creating testloader rank :", rank)
-            self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=-1)
-            self.validator = self.get_validator()
-            print("created testloader :", rank)
+            self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode="val")
+            validator = self.get_validator()
+            # init metric, for plot_results
+            metric_keys = validator.metric_keys + self.label_loss_items(prefix="val")
+            self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
+            self.validator = validator
             self.ema = ModelEMA(self.model)
 
     def _do_train(self, rank=-1, world_size=1):
@@ -172,7 +178,7 @@ class BaseTrainer:
             self.model = self.model.to(self.device)
 
         self.trigger_callbacks("before_train")
-        self._setup_train(rank)
+        self._setup_train(rank, world_size)
 
         self.epoch = 0
         self.epoch_time = None
@@ -181,13 +187,17 @@ class BaseTrainer:
         nb = len(self.train_loader)  # number of batches
         nw = max(round(self.args.warmup_epochs * nb), 100)  # number of warmup iterations
         last_opt_step = -1
-        for epoch in range(self.args.epochs):
+        for epoch in range(self.epochs):
             self.trigger_callbacks("on_epoch_start")
             self.model.train()
+            if rank != -1:
+                self.train_loader.sampler.set_epoch(epoch)
             pbar = enumerate(self.train_loader)
             if rank in {-1, 0}:
+                self.console.info(self.progress_string())
                 pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader), bar_format=TQDM_BAR_FORMAT)
             self.tloss = None
+            self.optimizer.zero_grad()
             for i, batch in pbar:
                 self.trigger_callbacks("on_batch_start")
                 # forward
@@ -197,7 +207,7 @@ class BaseTrainer:
                 ni = i + nb * epoch
                 if ni <= nw:
                     xi = [0, nw]  # x interp
-                    accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.args.batch_size]).round())
+                    self.accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round())
                     for j, x in enumerate(self.optimizer.param_groups):
                         # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                         x['lr'] = np.interp(
@@ -207,37 +217,47 @@ class BaseTrainer:
 
                 preds = self.model(batch["img"])
                 self.loss, self.loss_items = self.criterion(preds, batch)
+                if rank != -1:
+                    self.loss *= world_size
                 self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
                                 else self.loss_items
 
                 # backward
-                self.model.zero_grad(set_to_none=True)
                 self.scaler.scale(self.loss).backward()
 
                 # optimize
-                if ni - last_opt_step >= accumulate:
+                if ni - last_opt_step >= self.accumulate:
                     self.optimizer_step()
                     last_opt_step = ni
 
                 # log
-                mem = (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
+                mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
                 loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
                 losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
                 if rank in {-1, 0}:
                     pbar.set_description(
-                        (" {} " + "{:.3f}  " * (1 + loss_len) + ' {} ').format(f'{epoch + 1}/{self.args.epochs}', mem,
-                                                                               *losses, batch["img"].shape[-1]))
+                        ('%11s' * 2 + '%11.4g' * (2 + loss_len)) %
+                        (f'{epoch + 1}/{self.epochs}', mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1]))
                     self.trigger_callbacks('on_batch_end')
+                    if self.args.plots and ni < 3:
+                        self.plot_training_samples(batch, ni)
+
+            lr = {f"lr{ir}": x['lr'] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
+            self.scheduler.step()
 
             if rank in [-1, 0]:
                 # validation
                 self.trigger_callbacks('on_val_start')
                 self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
-                self.metrics, self.fitness = self.validate()
+                final_epoch = (epoch + 1 == self.epochs)
+                if not self.args.noval or final_epoch:
+                    self.metrics, self.fitness = self.validate()
                 self.trigger_callbacks('on_val_end')
+                log_vals = self.label_loss_items(self.tloss) | self.metrics | lr
+                self.save_metrics(metrics=log_vals)
 
                 # save model
-                if (not self.args.nosave) or (self.epoch + 1 == self.args.epochs):
+                if (not self.args.nosave) or (self.epoch + 1 == self.epochs):
                     self.save_model()
                     self.trigger_callbacks('on_model_save')
 
@@ -248,9 +268,15 @@ class BaseTrainer:
 
             # TODO: termination condition
 
-        self.log(f"\nTraining complete ({(time.time() - self.train_time_start) / 3600:.3f} hours)")
-        self.trigger_callbacks('on_train_end')
+        if rank in [-1, 0]:
+            # do the last evaluation with best.pt
+            self.final_eval()
+            if self.args.plots:
+                self.plot_metrics()
+            self.log(f"\nTraining complete ({(time.time() - self.train_time_start) / 3600:.3f} hours)")
+            self.trigger_callbacks('on_train_end')
         dist.destroy_process_group() if world_size != 1 else None
+        torch.cuda.empty_cache()
 
     def save_model(self):
         ckpt = {
@@ -306,7 +332,7 @@ class BaseTrainer:
         "fitness" metric.
         """
         metrics = self.validator(self)
-        fitness = metrics.get("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
+        fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
         if not self.best_fitness or self.best_fitness < fitness:
             self.best_fitness = self.fitness
         return metrics, fitness
@@ -339,12 +365,12 @@ class BaseTrainer:
         """
         raise NotImplementedError("criterion function not implemented in trainer")
 
-    def label_loss_items(self, loss_items):
+    def label_loss_items(self, loss_items=None, prefix="train"):
         """
         Returns a loss dict with labelled training loss items tensor
         """
         # Not needed for classification but necessary for segmentation & detection
-        return {"loss": loss_items}
+        return {"loss": loss_items} if loss_items is not None else ["loss"]
 
     def set_model_attributes(self):
         """
@@ -355,6 +381,31 @@ class BaseTrainer:
     def build_targets(self, preds, targets):
         pass
 
+    def progress_string(self):
+        return ""
+
+    # TODO: may need to put these following functions into callback
+    def plot_training_samples(self, batch, ni):
+        pass
+
+    def save_metrics(self, metrics):
+        keys, vals = list(metrics.keys()), list(metrics.values())
+        n = len(metrics) + 1  # number of cols
+        s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n')  # header
+        with open(self.csv, 'a') as f:
+            f.write(s + ('%23.5g,' * n % tuple([self.epoch] + vals)).rstrip(',') + '\n')
+
+    def plot_metrics(self):
+        pass
+
+    def final_eval(self):
+        # TODO: need standalone evaluator to do this
+        for f in self.last, self.best:
+            if f.exists():
+                strip_optimizer(f)  # strip optimizers
+                if f is self.best:
+                    self.console.info(f'\nValidating {f}...')
+
 
 def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
     # TODO: 1. docstring with example? 2. Move this inside Trainer? or utils?
@@ -382,7 +433,7 @@ def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
 
     optimizer.add_param_group({'params': g[0], 'weight_decay': decay})  # add g0 with weight_decay
     optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
-    LOGGER.info(f"optimizer: {type(optimizer).__name__}(lr={lr}) with parameter groups "
+    LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
                 f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias")
     return optimizer
 
diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py
index 24a840ea..bc2fdf79 100644
--- a/ultralytics/yolo/engine/validator.py
+++ b/ultralytics/yolo/engine/validator.py
@@ -1,4 +1,5 @@
 import logging
+from pathlib import Path
 
 import torch
 from omegaconf import OmegaConf
@@ -6,6 +7,7 @@ from tqdm import tqdm
 
 from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
 from ultralytics.yolo.utils import TQDM_BAR_FORMAT
+from ultralytics.yolo.utils.files import increment_path
 from ultralytics.yolo.utils.ops import Profile
 from ultralytics.yolo.utils.torch_utils import de_parallel, select_device
 
@@ -15,16 +17,17 @@ class BaseValidator:
     Base validator class.
     """
 
-    def __init__(self, dataloader, pbar=None, logger=None, args=None):
+    def __init__(self, dataloader, save_dir=None, pbar=None, logger=None, args=None):
         self.dataloader = dataloader
         self.pbar = pbar
         self.logger = logger or logging.getLogger()
         self.args = args or OmegaConf.load(DEFAULT_CONFIG)
         self.device = select_device(self.args.device, dataloader.batch_size)
+        self.save_dir = save_dir if save_dir is not None else \
+                increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
         self.cuda = self.device.type != 'cpu'
         self.batch_i = None
         self.training = True
-        self.loss = None
 
     def __call__(self, trainer=None, model=None):
         """
@@ -35,20 +38,22 @@ class BaseValidator:
         if self.training:
             model = trainer.ema.ema or trainer.model
             self.args.half &= self.device.type != 'cpu'
-            # NOTE: half() inference in evaluation will make training stuck,
-            # so I comment it out for now, I think we can reuse half mode after we add EMA.
             model = model.half() if self.args.half else model.float()
+            loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
         else:  # TODO: handle this when detectMultiBackend is supported
             assert model is not None, "Either trainer or model is needed for validation"
             # model = DetectMultiBacked(model)
             # TODO: implement init_model_attributes()
 
         model.eval()
+
         dt = Profile(), Profile(), Profile(), Profile()
-        self.loss = 0
         n_batches = len(self.dataloader)
         desc = self.get_desc()
-        bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT)
+        # NOTE: keeping this `not self.training` in tqdm will eliminate pbar after finishing segmantation evaluation during training,
+        # so I removed it, not sure if this will affect classification task cause I saw we use this arg in yolov5/classify/val.py.
+        # bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT)
+        bar = tqdm(self.dataloader, desc, n_batches, bar_format=TQDM_BAR_FORMAT)
         self.init_metrics(de_parallel(model))
         with torch.no_grad():
             for batch_i, batch in enumerate(bar):
@@ -59,20 +64,23 @@ class BaseValidator:
 
                 # inference
                 with dt[1]:
-                    preds = model(batch["img"].float())
+                    preds = model(batch["img"])
                     # TODO: remember to add native augmentation support when implementing model, like:
                     #  preds, train_out = model(im, augment=augment)
 
                 # loss
                 with dt[2]:
                     if self.training:
-                        self.loss += trainer.criterion(preds, batch)[0]
+                        loss += trainer.criterion(preds, batch)[1]
 
                 # pre-process predictions
                 with dt[3]:
                     preds = self.postprocess(preds)
 
                 self.update_metrics(preds, batch)
+                if self.args.plots and batch_i < 3:
+                    self.plot_val_samples(batch, batch_i)
+                    self.plot_predictions(batch, preds, batch_i)
 
         stats = self.get_stats()
         self.check_stats(stats)
@@ -81,7 +89,7 @@ class BaseValidator:
 
         # print speeds
         if not self.training:
-            t = tuple(x.t / len(self.dataloader.dataset.samples) * 1E3 for x in dt)  # speeds per image
+            t = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt)  # speeds per image
             # shape = (self.dataloader.batch_size, 3, imgsz, imgsz)
             self.logger.info(
                 'Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image at shape ' % t)
@@ -90,7 +98,8 @@ class BaseValidator:
             model.float()
         # TODO: implement save json
 
-        return stats
+        return stats | trainer.label_loss_items(loss.cpu() / len(self.dataloader), prefix="val") \
+                if self.training else stats
 
     def preprocess(self, batch):
         return batch
@@ -105,7 +114,7 @@ class BaseValidator:
         pass
 
     def get_stats(self):
-        pass
+        return {}
 
     def check_stats(self, stats):
         pass
@@ -115,3 +124,14 @@ class BaseValidator:
 
     def get_desc(self):
         pass
+
+    @property
+    def metric_keys(self):
+        return []
+
+    # TODO: may need to put these following functions into callback
+    def plot_val_samples(self, batch, ni):
+        pass
+
+    def plot_predictions(self, batch, preds, ni):
+        pass
diff --git a/ultralytics/yolo/utils/__init__.py b/ultralytics/yolo/utils/__init__.py
index 0216ec37..c171a48b 100644
--- a/ultralytics/yolo/utils/__init__.py
+++ b/ultralytics/yolo/utils/__init__.py
@@ -3,6 +3,7 @@ import logging.config
 import os
 import platform
 import sys
+import threading
 from pathlib import Path
 
 # Constants
@@ -130,3 +131,13 @@ class TryExcept(contextlib.ContextDecorator):
         if value:
             print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}"))
         return True
+
+
+def threaded(func):
+    # Multi-threads a target function and returns thread. Usage: @threaded decorator
+    def wrapper(*args, **kwargs):
+        thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
+        thread.start()
+        return thread
+
+    return wrapper
diff --git a/ultralytics/yolo/utils/configs/default.yaml b/ultralytics/yolo/utils/configs/default.yaml
index 99c1c106..fe50a3dd 100644
--- a/ultralytics/yolo/utils/configs/default.yaml
+++ b/ultralytics/yolo/utils/configs/default.yaml
@@ -26,11 +26,11 @@ deterministic: True
 local_rank: -1
 single_cls: False  # train multi-class data as single-class
 image_weights: False  # use weighted image selection for training
-shuffle: True
 rect: False  # support rectangular training
 cos_lr: False # Use cosine LR scheduler
 overlap_mask: True  # Segmentation masks overlap
 mask_ratio: 4  # Segmentation mask downsample ratio
+noval: False
 
 # Val/Test settings ----------------------------------------------------------------------------------------------------
 save_json: False
@@ -43,7 +43,7 @@ plots: False
 save_txt: False
 
 # Hyperparameters ------------------------------------------------------------------------------------------------------
-lr0: 0.001  # initial learning rate (SGD=1E-2, Adam=1E-3)
+lr0: 0.01  # initial learning rate (SGD=1E-2, Adam=1E-3)
 lrf: 0.01  # final OneCycleLR learning rate (lr0 * lrf)
 momentum: 0.937  # SGD momentum/Adam beta1
 weight_decay: 0.0005  # optimizer weight decay 5e-4
@@ -59,22 +59,23 @@ iou_t: 0.20  # IoU training threshold
 anchor_t: 4.0  # anchor-multiple threshold
 # anchors: 3  # anchors per output layer (0 to ignore)
 fl_gamma: 0.0  # focal loss gamma (efficientDet default gamma=1.5)
-hsv_h: 0.015  # image HSV-Hue augmentation (fraction)
-hsv_s: 0.7  # image HSV-Saturation augmentation (fraction)
-hsv_v: 0.4  # image HSV-Value augmentation (fraction)
-degrees: 0.0  # image rotation (+/- deg)
-translate: 0.1  # image translation (+/- fraction)
-scale: 0.5  # image scale (+/- gain)
-shear: 0.0  # image shear (+/- deg)
-perspective: 0.0  # image perspective (+/- fraction), range 0-0.001
-flipud: 0.0  # image flip up-down (probability)
-fliplr: 0.5  # image flip left-right (probability)
-mosaic: 1.0  # image mosaic (probability)
-mixup: 0.0  # image mixup (probability)
-copy_paste: 0.0  # segment copy-paste (probability)
 label_smoothing: 0.0
 nbs: 64 # nominal batch size
 # anchors: 3
+augment_hyp:
+  hsv_h: 0.015  # image HSV-Hue augmentation (fraction)
+  hsv_s: 0.7  # image HSV-Saturation augmentation (fraction)
+  hsv_v: 0.4  # image HSV-Value augmentation (fraction)
+  degrees: 0.0  # image rotation (+/- deg)
+  translate: 0.1  # image translation (+/- fraction)
+  scale: 0.5  # image scale (+/- gain)
+  shear: 0.0  # image shear (+/- deg)
+  perspective: 0.0  # image perspective (+/- fraction), range 0-0.001
+  flipud: 0.0  # image flip up-down (probability)
+  fliplr: 0.5  # image flip left-right (probability)
+  mosaic: 1.0  # image mosaic (probability)
+  mixup: 0.0  # image mixup (probability)
+  copy_paste: 0.0  # segment copy-paste (probability)
 
 # Hydra configs --------------------------------------------------------------------------------------------------------
 hydra:
diff --git a/ultralytics/yolo/utils/metrics.py b/ultralytics/yolo/utils/metrics.py
index 62bdcc99..1e05843b 100644
--- a/ultralytics/yolo/utils/metrics.py
+++ b/ultralytics/yolo/utils/metrics.py
@@ -283,6 +283,50 @@ def smooth(y, f=0.05):
     return np.convolve(yp, np.ones(nf) / nf, mode='valid')  # y-smoothed
 
 
+def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
+    # Precision-recall curve
+    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
+    py = np.stack(py, axis=1)
+
+    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
+        for i, y in enumerate(py.T):
+            ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}')  # plot(recall, precision)
+    else:
+        ax.plot(px, py, linewidth=1, color='grey')  # plot(recall, precision)
+
+    ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
+    ax.set_xlabel('Recall')
+    ax.set_ylabel('Precision')
+    ax.set_xlim(0, 1)
+    ax.set_ylim(0, 1)
+    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
+    ax.set_title('Precision-Recall Curve')
+    fig.savefig(save_dir, dpi=250)
+    plt.close(fig)
+
+
+def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
+    # Metric-confidence curve
+    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
+
+    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
+        for i, y in enumerate(py):
+            ax.plot(px, y, linewidth=1, label=f'{names[i]}')  # plot(confidence, metric)
+    else:
+        ax.plot(px, py.T, linewidth=1, color='grey')  # plot(confidence, metric)
+
+    y = smooth(py.mean(0), 0.05)
+    ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}')
+    ax.set_xlabel(xlabel)
+    ax.set_ylabel(ylabel)
+    ax.set_xlim(0, 1)
+    ax.set_ylim(0, 1)
+    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
+    ax.set_title(f'{ylabel}-Confidence Curve')
+    fig.savefig(save_dir, dpi=250)
+    plt.close(fig)
+
+
 def compute_ap(recall, precision):
     """ Compute the average precision, given the recall and precision curves
     # Arguments
@@ -365,14 +409,11 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
     f1 = 2 * p * r / (p + r + eps)
     names = [v for k, v in names.items() if k in unique_classes]  # list: only classes that have data
     names = dict(enumerate(names))  # to dict
-    # TODO: plot
-    '''
     if plot:
         plot_pr_curve(px, py, ap, Path(save_dir) / f'{prefix}PR_curve.png', names)
         plot_mc_curve(px, f1, Path(save_dir) / f'{prefix}F1_curve.png', names, ylabel='F1')
         plot_mc_curve(px, p, Path(save_dir) / f'{prefix}P_curve.png', names, ylabel='Precision')
         plot_mc_curve(px, r, Path(save_dir) / f'{prefix}R_curve.png', names, ylabel='Recall')
-    '''
 
     i = smooth(f1.mean(0), 0.1).argmax()  # max F1 index
     p, r, f1 = p[:, i], r[:, i], f1[:, i]
diff --git a/ultralytics/yolo/utils/plotting.py b/ultralytics/yolo/utils/plotting.py
index c983eee1..2bad2819 100644
--- a/ultralytics/yolo/utils/plotting.py
+++ b/ultralytics/yolo/utils/plotting.py
@@ -1,12 +1,16 @@
+import contextlib
+import math
 from pathlib import Path
 from urllib.error import URLError
 
 import cv2
+import matplotlib.pyplot as plt
 import numpy as np
+import pandas as pd
 import torch
 from PIL import Image, ImageDraw, ImageFont
 
-from ultralytics.yolo.utils import FONT, USER_CONFIG_DIR
+from ultralytics.yolo.utils import FONT, USER_CONFIG_DIR, threaded
 
 from .checks import check_font, check_requirements, is_ascii
 from .files import increment_path
@@ -179,3 +183,147 @@ def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False,
         # cv2.imwrite(f, crop)  # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
         Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0)  # save RGB
     return crop
+
+
+@threaded
+def plot_images_and_masks(images, batch_idx, cls, bboxes, masks, paths, confs=None, fname='images.jpg', names=None):
+    # Plot image grid with labels
+    if isinstance(images, torch.Tensor):
+        images = images.cpu().float().numpy()
+    if isinstance(cls, torch.Tensor):
+        cls = cls.cpu().numpy()
+    if isinstance(bboxes, torch.Tensor):
+        bboxes = bboxes.cpu().numpy()
+    if isinstance(masks, torch.Tensor):
+        masks = masks.cpu().numpy().astype(int)
+    if isinstance(batch_idx, torch.Tensor):
+        batch_idx = batch_idx.cpu().numpy()
+
+    max_size = 1920  # max image size
+    max_subplots = 16  # max image subplots, i.e. 4x4
+    bs, _, h, w = images.shape  # batch size, _, height, width
+    bs = min(bs, max_subplots)  # limit plot images
+    ns = np.ceil(bs ** 0.5)  # number of subplots (square)
+    if np.max(images[0]) <= 1:
+        images *= 255  # de-normalise (optional)
+
+    # Build Image
+    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
+    for i, im in enumerate(images):
+        if i == max_subplots:  # if last batch has fewer images than we expect
+            break
+        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
+        im = im.transpose(1, 2, 0)
+        mosaic[y:y + h, x:x + w, :] = im
+
+    # Resize (optional)
+    scale = max_size / ns / max(h, w)
+    if scale < 1:
+        h = math.ceil(scale * h)
+        w = math.ceil(scale * w)
+        mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
+
+    # Annotate
+    fs = int((h + w) * ns * 0.01)  # font size
+    annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
+    for i in range(i + 1):
+        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
+        annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2)  # borders
+        if paths:
+            annotator.text((x + 5, y + 5 + h), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220))  # filenames
+        if len(cls) > 0:
+            idx = batch_idx == i
+
+            boxes = xywh2xyxy(bboxes[idx]).T
+            classes = cls[idx].astype('int')
+            labels = confs is None  # labels if no conf column
+            conf = None if labels else confs[idx]  # check for confidence presence (label vs pred)
+
+            if boxes.shape[1]:
+                if boxes.max() <= 1.01:  # if normalized with tolerance 0.01
+                    boxes[[0, 2]] *= w  # scale to pixels
+                    boxes[[1, 3]] *= h
+                elif scale < 1:  # absolute coords need scale if image scales
+                    boxes *= scale
+            boxes[[0, 2]] += x
+            boxes[[1, 3]] += y
+            for j, box in enumerate(boxes.T.tolist()):
+                c = classes[j]
+                color = colors(c)
+                c = names[c] if names else c
+                if labels or conf[j] > 0.25:  # 0.25 conf thresh
+                    label = f'{c}' if labels else f'{c} {conf[j]:.1f}'
+                    annotator.box_label(box, label, color=color)
+
+            # Plot masks
+            if len(masks):
+                if masks.max() > 1.0:  # mean that masks are overlap
+                    image_masks = masks[[i]]  # (1, 640, 640)
+                    nl = idx.sum()
+                    index = np.arange(nl).reshape(nl, 1, 1) + 1
+                    image_masks = np.repeat(image_masks, nl, axis=0)
+                    image_masks = np.where(image_masks == index, 1.0, 0.0)
+                else:
+                    image_masks = masks[idx]
+
+                im = np.asarray(annotator.im).copy()
+                for j, box in enumerate(boxes.T.tolist()):
+                    if labels or conf[j] > 0.25:  # 0.25 conf thresh
+                        color = colors(classes[j])
+                        mh, mw = image_masks[j].shape
+                        if mh != h or mw != w:
+                            mask = image_masks[j].astype(np.uint8)
+                            mask = cv2.resize(mask, (w, h))
+                            mask = mask.astype(bool)
+                        else:
+                            mask = image_masks[j].astype(bool)
+                        with contextlib.suppress(Exception):
+                            im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6
+                annotator.fromarray(im)
+    annotator.im.save(fname)  # save
+
+
+def plot_results_with_masks(file="path/to/results.csv", dir="", best=True):
+    # Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
+    save_dir = Path(file).parent if file else Path(dir)
+    fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
+    ax = ax.ravel()
+    files = list(save_dir.glob("results*.csv"))
+    assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
+    for f in files:
+        try:
+            data = pd.read_csv(f)
+            index = np.argmax(0.9 * data.values[:, 8] + 0.1 * data.values[:, 7] + 0.9 * data.values[:, 12] +
+                              0.1 * data.values[:, 11])
+            s = [x.strip() for x in data.columns]
+            x = data.values[:, 0]
+            for i, j in enumerate([1, 2, 3, 4, 5, 6, 9, 10, 13, 14, 15, 16, 7, 8, 11, 12]):
+                y = data.values[:, j]
+                # y[y == 0] = np.nan  # don't show zero values
+                ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=2)
+                if best:
+                    # best
+                    ax[i].scatter(index, y[index], color="r", label=f"best:{index}", marker="*", linewidth=3)
+                    ax[i].set_title(s[j] + f"\n{round(y[index], 5)}")
+                else:
+                    # last
+                    ax[i].scatter(x[-1], y[-1], color="r", label="last", marker="*", linewidth=3)
+                    ax[i].set_title(s[j] + f"\n{round(y[-1], 5)}")
+                # if j in [8, 9, 10]:  # share train and val loss y axes
+                #     ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
+        except Exception as e:
+            print(f"Warning: Plotting error for {f}: {e}")
+    ax[1].legend()
+    fig.savefig(save_dir / "results.png", dpi=200)
+    plt.close()
+
+
+def output_to_target(output, max_det=300):
+    # Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting
+    targets = []
+    for i, o in enumerate(output):
+        box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
+        j = torch.full((conf.shape[0], 1), i)
+        targets.append(torch.cat((j, cls, xyxy2xywh(box), conf), 1))
+    targets = torch.cat(targets, 0).numpy()
+    return targets[:, 0], targets[:, 1], targets[:, 2:6], targets[:, 6]
diff --git a/ultralytics/yolo/utils/torch_utils.py b/ultralytics/yolo/utils/torch_utils.py
index 645b8605..dea42d8a 100644
--- a/ultralytics/yolo/utils/torch_utils.py
+++ b/ultralytics/yolo/utils/torch_utils.py
@@ -245,3 +245,19 @@ class ModelEMA:
     def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
         # Update EMA attributes
         copy_attr(self.ema, model, include, exclude)
+
+
+def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
+    # Strip optimizer from 'f' to finalize training, optionally save as 's'
+    x = torch.load(f, map_location=torch.device('cpu'))
+    if x.get('ema'):
+        x['model'] = x['ema']  # replace model with ema
+    for k in 'optimizer', 'best_fitness', 'ema', 'updates':  # keys
+        x[k] = None
+    x['epoch'] = -1
+    x['model'].half()  # to FP16
+    for p in x['model'].parameters():
+        p.requires_grad = False
+    torch.save(x, s or f)
+    mb = os.path.getsize(s or f) / 1E6  # filesize
+    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
diff --git a/ultralytics/yolo/v8/classify/train.py b/ultralytics/yolo/v8/classify/train.py
index 4037b833..fa00ab29 100644
--- a/ultralytics/yolo/v8/classify/train.py
+++ b/ultralytics/yolo/v8/classify/train.py
@@ -9,6 +9,9 @@ from ultralytics.yolo.utils.modeling.tasks import ClassificationModel
 
 class ClassificationTrainer(BaseTrainer):
 
+    def set_model_attributes(self):
+        self.model.names = self.data["names"]
+
     def load_model(self, model_cfg, weights, data):
         # TODO: why treat clf models as unique. We should have clf yamls?
         if weights and not weights.__class__.__name__.startswith("yolo"):  # torchvision
@@ -18,7 +21,7 @@ class ClassificationTrainer(BaseTrainer):
         ClassificationModel.reshape_outputs(model, data["nc"])
         return model
 
-    def get_dataloader(self, dataset_path, batch_size=None, rank=0):
+    def get_dataloader(self, dataset_path, batch_size, rank=0, mode="train"):
         return build_classification_dataloader(path=dataset_path,
                                                imgsz=self.args.img_size,
                                                batch_size=batch_size,
diff --git a/ultralytics/yolo/v8/classify/val.py b/ultralytics/yolo/v8/classify/val.py
index 9fcfc6e3..ae5e5bdc 100644
--- a/ultralytics/yolo/v8/classify/val.py
+++ b/ultralytics/yolo/v8/classify/val.py
@@ -23,3 +23,7 @@ class ClassificationValidator(BaseValidator):
         acc = torch.stack((self.correct[:, 0], self.correct.max(1).values), dim=1)  # (top1, top5) accuracy
         top1, top5 = acc.mean(0).tolist()
         return {"top1": top1, "top5": top5, "fitness": top5}
+
+    @property
+    def metric_keys(self):
+        return ["top1", "top5"]
diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py
index 16f2ea97..0e1cb548 100644
--- a/ultralytics/yolo/v8/segment/train.py
+++ b/ultralytics/yolo/v8/segment/train.py
@@ -9,30 +9,18 @@ from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
 from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE
 from ultralytics.yolo.utils.modeling.tasks import SegmentationModel
 from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy
+from ultralytics.yolo.utils.plotting import plot_images_and_masks, plot_results_with_masks
 from ultralytics.yolo.utils.torch_utils import de_parallel
 
 
 # BaseTrainer python usage
 class SegmentationTrainer(BaseTrainer):
 
-    def get_dataloader(self, dataset_path, batch_size, rank=0):
+    def get_dataloader(self, dataset_path, batch_size, mode="train", rank=0):
         # TODO: manage splits differently
         # calculate stride - check if model is initialized
         gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
-        return build_dataloader(
-            img_path=dataset_path,
-            img_size=self.args.img_size,
-            batch_size=batch_size,
-            single_cls=self.args.single_cls,
-            cache=self.args.cache,
-            image_weights=self.args.image_weights,
-            stride=gs,
-            rect=self.args.rect,
-            rank=rank,
-            workers=self.args.workers,
-            shuffle=self.args.shuffle,
-            use_segments=True,
-        )[0]
+        return build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, rank=rank, mode=mode)[0]
 
     def preprocess_batch(self, batch):
         batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
@@ -58,7 +46,10 @@ class SegmentationTrainer(BaseTrainer):
         self.model.names = self.data["names"]
 
     def get_validator(self):
-        return v8.segment.SegmentationValidator(self.test_loader, self.device, logger=self.console)
+        return v8.segment.SegmentationValidator(self.test_loader,
+                                                save_dir=self.save_dir,
+                                                logger=self.console,
+                                                args=self.args)
 
     def criterion(self, preds, batch):
         head = de_parallel(self.model).model[-1]
@@ -218,6 +209,8 @@ class SegmentationTrainer(BaseTrainer):
                     else:
                         mask_gti = masks[tidxs[i]][j]
                     lseg += single_mask_loss(mask_gti, pmask[j], proto[bi], mxyxy[j], marea[j])
+            else:
+                lseg += (proto * 0).sum()
 
             obji = BCEobj(pi[..., 4], tobj)
             lobj += obji * balance[i]  # obj loss
@@ -234,15 +227,33 @@ class SegmentationTrainer(BaseTrainer):
         loss = lbox + lobj + lcls + lseg
         return loss * bs, torch.cat((lbox, lseg, lobj, lcls)).detach()
 
-    def label_loss_items(self, loss_items):
+    def label_loss_items(self, loss_items=None, prefix="train"):
         # We should just use named tensors here in future
-        keys = ["lbox", "lseg", "lobj", "lcls"]
-        return dict(zip(keys, loss_items))
+        keys = [f"{prefix}/lbox", f"{prefix}/lseg", f"{prefix}/lobj", f"{prefix}/lcls"]
+        return dict(zip(keys, loss_items)) if loss_items is not None else keys
 
     def progress_string(self):
         return ('\n' + '%11s' * 7) % \
                ('Epoch', 'GPU_mem', 'box_loss', 'seg_loss', 'obj_loss', 'cls_loss', 'Size')
 
+    def plot_training_samples(self, batch, ni):
+        images = batch["img"]
+        masks = batch["masks"]
+        cls = batch["cls"].squeeze(-1)
+        bboxes = batch["bboxes"]
+        paths = batch["im_file"]
+        batch_idx = batch["batch_idx"]
+        plot_images_and_masks(images,
+                              batch_idx,
+                              cls,
+                              bboxes,
+                              masks,
+                              paths,
+                              fname=self.save_dir / f"train_batch{ni}.jpg")
+
+    def plot_metrics(self):
+        plot_results_with_masks(file=self.csv)  # save results.png
+
 
 @hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
 def train(cfg):
diff --git a/ultralytics/yolo/v8/segment/val.py b/ultralytics/yolo/v8/segment/val.py
index 372f3067..2669eca4 100644
--- a/ultralytics/yolo/v8/segment/val.py
+++ b/ultralytics/yolo/v8/segment/val.py
@@ -6,23 +6,24 @@ import torch.nn.functional as F
 
 from ultralytics.yolo.engine.validator import BaseValidator
 from ultralytics.yolo.utils import ops
-from ultralytics.yolo.utils.checks import check_requirements
+from ultralytics.yolo.utils.checks import check_file, check_requirements
 from ultralytics.yolo.utils.files import yaml_load
 from ultralytics.yolo.utils.metrics import (ConfusionMatrix, Metrics, ap_per_class_box_and_mask, box_iou,
                                             fitness_segmentation, mask_iou)
+from ultralytics.yolo.utils.plotting import output_to_target, plot_images_and_masks
 from ultralytics.yolo.utils.torch_utils import de_parallel
 
 
 class SegmentationValidator(BaseValidator):
 
-    def __init__(self, dataloader, pbar=None, logger=None, args=None):
-        super().__init__(dataloader, pbar, logger, args)
+    def __init__(self, dataloader, save_dir=None, pbar=None, logger=None, args=None):
+        super().__init__(dataloader, save_dir, pbar, logger, args)
         if self.args.save_json:
             check_requirements(['pycocotools'])
             self.process = ops.process_mask_upsample  # more accurate
         else:
             self.process = ops.process_mask  # faster
-        self.data_dict = yaml_load(self.args.data) if self.args.data else None
+        self.data_dict = yaml_load(check_file(self.args.data)) if self.args.data else None
         self.is_coco = False
         self.class_map = None
         self.targets = None
@@ -62,6 +63,7 @@ class SegmentationValidator(BaseValidator):
         self.loss = torch.zeros(4, device=self.device)
         self.jdict = []
         self.stats = []
+        self.plot_masks = []
 
     def get_desc(self):
         return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P",
@@ -80,11 +82,10 @@ class SegmentationValidator(BaseValidator):
 
     def update_metrics(self, preds, batch):
         # Metrics
-        plot_masks = []  # masks for plotting
         for si, (pred, proto) in enumerate(zip(preds[0], preds[1])):
             labels = self.targets[self.targets[:, 0] == si, 1:]
             nl, npr = labels.shape[0], pred.shape[0]  # number of labels, predictions
-            shape = batch["shape"][si]
+            shape = batch["ori_shape"][si]
             # path = batch["shape"][si][0]
             correct_masks = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)  # init
             correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)  # init
@@ -130,7 +131,7 @@ class SegmentationValidator(BaseValidator):
 
             pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
             if self.args.plots and self.batch_i < 3:
-                plot_masks.append(pred_masks[:15].cpu())  # filter top 15 to plot
+                self.plot_masks.append(pred_masks[:15].cpu())  # filter top 15 to plot
 
             # TODO: Save/log
             '''
@@ -143,26 +144,14 @@ class SegmentationValidator(BaseValidator):
             # callbacks.run('on_val_image_end', pred, predn, path, names, im[si])
             '''
 
-        # TODO Plot images
-        '''
-        if self.args.plots and self.batch_i < 3:
-            if len(plot_masks):
-                plot_masks = torch.cat(plot_masks, dim=0)
-            plot_images_and_masks(im, targets, masks, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names)
-            plot_images_and_masks(im, output_to_target(preds, max_det=15), plot_masks, paths,
-                                  save_dir / f'val_batch{batch_i}_pred.jpg', names)  # pred
-        '''
-
     def get_stats(self):
         stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)]  # to numpy
         if len(stats) and stats[0].any():
-            # TODO: save_dir
-            results = ap_per_class_box_and_mask(*stats, plot=self.args.plots, save_dir='', names=self.names)
+            results = ap_per_class_box_and_mask(*stats, plot=self.args.plots, save_dir=self.save_dir, names=self.names)
             self.metrics.update(results)
         self.nt_per_class = np.bincount(stats[4].astype(int), minlength=self.nc)  # number of targets per class
-        keys = ["mp_bbox", "mr_bbox", "map50_bbox", "map_bbox", "mp_mask", "mr_mask", "map50_mask", "map_mask"]
         metrics = {"fitness": fitness_segmentation(np.array(self.metrics.mean_results()).reshape(1, -1))}
-        metrics |= zip(keys, self.metrics.mean_results())
+        metrics |= zip(self.metric_keys, self.metrics.mean_results())
         return metrics
 
     def print_results(self):
@@ -177,9 +166,8 @@ class SegmentationValidator(BaseValidator):
             for i, c in enumerate(self.metrics.ap_class_index):
                 self.logger.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
 
-        # plot TODO: save_dir
         if self.args.plots:
-            self.confusion_matrix.plot(save_dir='', names=list(self.names.values()))
+            self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values()))
 
     def _process_batch(self, detections, labels, iouv, pred_masks=None, gt_masks=None, overlap=False, masks=False):
         """
@@ -217,3 +205,41 @@ class SegmentationValidator(BaseValidator):
                     matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
                 correct[matches[:, 1].astype(int), i] = True
         return torch.tensor(correct, dtype=torch.bool, device=iouv.device)
+
+    @property
+    def metric_keys(self):
+        return [
+            "metrics/precision(B)",
+            "metrics/recall(B)",
+            "metrics/mAP_0.5(B)",
+            "metrics/mAP_0.5:0.95(B)",  # metrics
+            "metrics/precision(M)",
+            "metrics/recall(M)",
+            "metrics/mAP_0.5(M)",
+            "metrics/mAP_0.5:0.95(M)",]
+
+    def plot_val_samples(self, batch, ni):
+        images = batch["img"]
+        masks = batch["masks"]
+        cls = batch["cls"].squeeze(-1)
+        bboxes = batch["bboxes"]
+        paths = batch["im_file"]
+        batch_idx = batch["batch_idx"]
+        plot_images_and_masks(images,
+                              batch_idx,
+                              cls,
+                              bboxes,
+                              masks,
+                              paths,
+                              fname=self.save_dir / f"val_batch{ni}_labels.jpg",
+                              names=self.names)
+
+    def plot_predictions(self, batch, preds, ni):
+        images = batch["img"]
+        paths = batch["im_file"]
+        if len(self.plot_masks):
+            plot_masks = torch.cat(self.plot_masks, dim=0)
+        batch_idx, cls, bboxes, conf = output_to_target(preds[0], max_det=15)
+        plot_images_and_masks(images, batch_idx, cls, bboxes, plot_masks, paths, conf,
+                              self.save_dir / f'val_batch{ni}_pred.jpg', self.names)  # pred
+        self.plot_masks.clear()