diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 0a768a57..0ab515e6 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -90,15 +90,16 @@ jobs:
       - name: Test detection
         shell: bash  # for Windows compatibility
         run: |
-          echo "TODO"
+          yolo task=detect mode=train model=yolov5n.yaml data=coco128.yaml epochs=1 img_size=64
+          yolo task=detect mode=val model=runs/exp/weights/last.pt img_size=64
       - name: Test segmentation
         shell: bash  # for Windows compatibility
         # TODO: redo val test without hardcoded weights
         run: |
           yolo task=segment mode=train model=yolov5n-seg.yaml data=coco128-seg.yaml epochs=1 img_size=64
-          yolo task=segment mode=val model=runs/exp/weights/last.pt data=coco128-seg.yaml img_size=64
+          yolo task=segment mode=val model=runs/exp2/weights/last.pt data=coco128-seg.yaml img_size=64
       - name: Test classification
         shell: bash  # for Windows compatibility
         run: |
           yolo task=classify mode=train model=resnet18 data=mnist160 epochs=1 img_size=32
-          yolo task=classify mode=val model=runs/exp2/weights/last.pt data=mnist160
\ No newline at end of file
+          yolo task=classify mode=val model=runs/exp3/weights/last.pt data=mnist160
diff --git a/ultralytics/yolo/utils/metrics.py b/ultralytics/yolo/utils/metrics.py
index 1e05843b..1f80b92c 100644
--- a/ultralytics/yolo/utils/metrics.py
+++ b/ultralytics/yolo/utils/metrics.py
@@ -459,14 +459,14 @@ def ap_per_class_box_and_mask(
         "boxes": {
             "p": results_boxes[0],
             "r": results_boxes[1],
-            "ap": results_boxes[3],
             "f1": results_boxes[2],
+            "ap": results_boxes[3],
             "ap_class": results_boxes[4]},
         "masks": {
             "p": results_masks[0],
             "r": results_masks[1],
-            "ap": results_masks[3],
             "f1": results_masks[2],
+            "ap": results_masks[3],
             "ap_class": results_masks[4]}}
     return results
 
@@ -547,7 +547,7 @@ class Metric:
         Args:
             results: tuple(p, r, ap, f1, ap_class)
         """
-        p, r, all_ap, f1, ap_class_index = results
+        p, r, f1, all_ap, ap_class_index = results
         self.p = p
         self.r = r
         self.all_ap = all_ap
diff --git a/ultralytics/yolo/utils/plotting.py b/ultralytics/yolo/utils/plotting.py
index 2bad2819..cdd2fd21 100644
--- a/ultralytics/yolo/utils/plotting.py
+++ b/ultralytics/yolo/utils/plotting.py
@@ -186,7 +186,15 @@ def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False,
 
 
 @threaded
-def plot_images_and_masks(images, batch_idx, cls, bboxes, masks, paths, confs=None, fname='images.jpg', names=None):
+def plot_images_and_masks(images,
+                          batch_idx,
+                          cls,
+                          bboxes,
+                          masks,
+                          confs=None,
+                          paths=None,
+                          fname='images.jpg',
+                          names=None):
     # Plot image grid with labels
     if isinstance(images, torch.Tensor):
         images = images.cpu().float().numpy()
@@ -327,3 +335,99 @@ def output_to_target(output, max_det=300):
         targets.append(torch.cat((j, cls, xyxy2xywh(box), conf), 1))
     targets = torch.cat(targets, 0).numpy()
     return targets[:, 0], targets[:, 1], targets[:, 2:6], targets[:, 6]
+
+
+@threaded
+def plot_images(images, batch_idx, cls, bboxes, confs=None, paths=None, fname='images.jpg', names=None):
+    # Plot image grid with labels
+    if isinstance(images, torch.Tensor):
+        images = images.cpu().float().numpy()
+    if isinstance(cls, torch.Tensor):
+        cls = cls.cpu().numpy()
+    if isinstance(bboxes, torch.Tensor):
+        bboxes = bboxes.cpu().numpy()
+    if isinstance(batch_idx, torch.Tensor):
+        batch_idx = batch_idx.cpu().numpy()
+
+    max_size = 1920  # max image size
+    max_subplots = 16  # max image subplots, i.e. 4x4
+    bs, _, h, w = images.shape  # batch size, _, height, width
+    bs = min(bs, max_subplots)  # limit plot images
+    ns = np.ceil(bs ** 0.5)  # number of subplots (square)
+    if np.max(images[0]) <= 1:
+        images *= 255  # de-normalise (optional)
+
+    # Build Image
+    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
+    for i, im in enumerate(images):
+        if i == max_subplots:  # if last batch has fewer images than we expect
+            break
+        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
+        im = im.transpose(1, 2, 0)
+        mosaic[y:y + h, x:x + w, :] = im
+
+    # Resize (optional)
+    scale = max_size / ns / max(h, w)
+    if scale < 1:
+        h = math.ceil(scale * h)
+        w = math.ceil(scale * w)
+        mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
+
+    # Annotate
+    fs = int((h + w) * ns * 0.01)  # font size
+    annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
+    for i in range(i + 1):
+        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
+        annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2)  # borders
+        if paths:
+            annotator.text((x + 5, y + 5 + h), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220))  # filenames
+        if len(cls) > 0:
+            idx = batch_idx == i
+
+            boxes = xywh2xyxy(bboxes[idx]).T
+            classes = cls[idx].astype('int')
+            labels = confs is None  # labels if no conf column
+            conf = None if labels else confs[idx]  # check for confidence presence (label vs pred)
+
+            if boxes.shape[1]:
+                if boxes.max() <= 1.01:  # if normalized with tolerance 0.01
+                    boxes[[0, 2]] *= w  # scale to pixels
+                    boxes[[1, 3]] *= h
+                elif scale < 1:  # absolute coords need scale if image scales
+                    boxes *= scale
+            boxes[[0, 2]] += x
+            boxes[[1, 3]] += y
+            for j, box in enumerate(boxes.T.tolist()):
+                c = classes[j]
+                color = colors(c)
+                c = names[c] if names else c
+                if labels or conf[j] > 0.25:  # 0.25 conf thresh
+                    label = f'{c}' if labels else f'{c} {conf[j]:.1f}'
+                    annotator.box_label(box, label, color=color)
+    annotator.im.save(fname)  # save
+
+
+def plot_results(file='path/to/results.csv', dir=''):
+    # Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
+    save_dir = Path(file).parent if file else Path(dir)
+    fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
+    ax = ax.ravel()
+    files = list(save_dir.glob('results*.csv'))
+    assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
+    for f in files:
+        try:
+            data = pd.read_csv(f)
+            s = [x.strip() for x in data.columns]
+            x = data.values[:, 0]
+            for i, j in enumerate([1, 2, 3, 4, 5, 8, 9, 10, 6, 7]):
+                y = data.values[:, j].astype('float')
+                # y[y == 0] = np.nan  # don't show zero values
+                ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8)
+                ax[i].set_title(s[j], fontsize=12)
+                # if j in [8, 9, 10]:  # share train and val loss y axes
+                #     ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
+        except Exception as e:
+            print(f'Warning: Plotting error for {f}: {e}')
+    ax[1].legend()
+    fig.savefig(save_dir / 'results.png', dpi=200)
+    plt.close()
diff --git a/ultralytics/yolo/v8/__init__.py b/ultralytics/yolo/v8/__init__.py
index a18b41ad..cec773ed 100644
--- a/ultralytics/yolo/v8/__init__.py
+++ b/ultralytics/yolo/v8/__init__.py
@@ -1,7 +1,7 @@
 from pathlib import Path
 
-from ultralytics.yolo.v8 import classify, segment
+from ultralytics.yolo.v8 import classify, detect, segment
 
 ROOT = Path(__file__).parents[0]  # yolov8 ROOT
 
-__all__ = ["classify", "segment"]
+__all__ = ["classify", "segment", "detect"]
diff --git a/ultralytics/yolo/v8/detect/__init__.py b/ultralytics/yolo/v8/detect/__init__.py
new file mode 100644
index 00000000..edce22ae
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/__init__.py
@@ -0,0 +1,2 @@
+from ultralytics.yolo.v8.detect.train import DetectionTrainer, train
+from ultralytics.yolo.v8.detect.val import DetectionValidator, val
diff --git a/ultralytics/yolo/v8/detect/train.py b/ultralytics/yolo/v8/detect/train.py
new file mode 100644
index 00000000..22e9e36e
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/train.py
@@ -0,0 +1,209 @@
+import hydra
+import torch
+import torch.nn as nn
+
+from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
+from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE
+from ultralytics.yolo.utils.modeling.tasks import DetectionModel
+from ultralytics.yolo.utils.plotting import plot_images, plot_results
+from ultralytics.yolo.utils.torch_utils import de_parallel
+
+from ..segment import SegmentationTrainer
+from .val import DetectionValidator
+
+
+# BaseTrainer python usage
+class DetectionTrainer(SegmentationTrainer):
+
+    def load_model(self, model_cfg, weights, data):
+        model = DetectionModel(model_cfg or weights["model"].yaml,
+                               ch=3,
+                               nc=data["nc"],
+                               anchors=self.args.get("anchors"))
+        if weights:
+            model.load(weights)
+        for _, v in model.named_parameters():
+            v.requires_grad = True  # train all layers
+        return model
+
+    def get_validator(self):
+        return DetectionValidator(self.test_loader, save_dir=self.save_dir, logger=self.console, args=self.args)
+
+    def criterion(self, preds, batch):
+        head = de_parallel(self.model).model[-1]
+        sort_obj_iou = False
+        autobalance = False
+
+        # init losses
+        BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([self.args.cls_pw], device=self.device))
+        BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([self.args.obj_pw], device=self.device))
+
+        # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
+        cp, cn = smooth_BCE(eps=self.args.label_smoothing)  # positive, negative BCE targets
+
+        # Focal loss
+        g = self.args.fl_gamma
+        if self.args.fl_gamma > 0:
+            BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
+
+        balance = {3: [4.0, 1.0, 0.4]}.get(head.nl, [4.0, 1.0, 0.25, 0.06, 0.02])  # P3-P7
+        ssi = list(head.stride).index(16) if autobalance else 0  # stride 16 index
+        BCEcls, BCEobj, gr, autobalance = BCEcls, BCEobj, 1.0, autobalance
+
+        def build_targets(p, targets):
+            # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
+            nonlocal head
+            na, nt = head.na, targets.shape[0]  # number of anchors, targets
+            tcls, tbox, indices, anch = [], [], [], []
+            gain = torch.ones(7, device=self.device)  # normalized to gridspace gain
+            ai = torch.arange(na, device=self.device).float().view(na, 1).repeat(1, nt)
+            targets = torch.cat((targets.repeat(na, 1, 1), ai[..., None]), 2)  # append anchor indices
+
+            g = 0.5  # bias
+            off = torch.tensor(
+                [
+                    [0, 0],
+                    [1, 0],
+                    [0, 1],
+                    [-1, 0],
+                    [0, -1],  # j,k,l,m
+                    # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
+                ],
+                device=self.device).float() * g  # offsets
+
+            for i in range(head.nl):
+                anchors, shape = head.anchors[i], p[i].shape
+                gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]]  # xyxy gain
+
+                # Match targets to anchors
+                t = targets * gain  # shape(3,n,7)
+                if nt:
+                    # Matches
+                    r = t[..., 4:6] / anchors[:, None]  # wh ratio
+                    j = torch.max(r, 1 / r).max(2)[0] < self.args.anchor_t  # compare
+                    # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
+                    t = t[j]  # filter
+
+                    # Offsets
+                    gxy = t[:, 2:4]  # grid xy
+                    gxi = gain[[2, 3]] - gxy  # inverse
+                    j, k = ((gxy % 1 < g) & (gxy > 1)).T
+                    l, m = ((gxi % 1 < g) & (gxi > 1)).T
+                    j = torch.stack((torch.ones_like(j), j, k, l, m))
+                    t = t.repeat((5, 1, 1))[j]
+                    offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
+                else:
+                    t = targets[0]
+                    offsets = 0
+
+                # Define
+                bc, gxy, gwh, a = t.chunk(4, 1)  # (image, class), grid xy, grid wh, anchors
+                a, (b, c) = a.long().view(-1), bc.long().T  # anchors, image, class
+                gij = (gxy - offsets).long()
+                gi, gj = gij.T  # grid indices
+
+                # Append
+                indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1)))  # image, anchor, grid
+                tbox.append(torch.cat((gxy - gij, gwh), 1))  # box
+                anch.append(anchors[a])  # anchors
+                tcls.append(c)  # class
+
+            return tcls, tbox, indices, anch
+
+        if len(preds) == 2:  # eval
+            _, p = preds
+        else:  # len(3) train
+            p = preds
+
+        targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
+        targets = targets.to(self.device)
+
+        lcls = torch.zeros(1, device=self.device)
+        lbox = torch.zeros(1, device=self.device)
+        lobj = torch.zeros(1, device=self.device)
+        tcls, tbox, indices, anchors = build_targets(p, targets)
+
+        # Losses
+        for i, pi in enumerate(p):  # layer index, layer predictions
+            b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
+            tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device)  # target obj
+            bs = tobj.shape[0]
+            n = b.shape[0]  # number of targets
+            if n:
+                pxy, pwh, _, pcls = pi[b, a, gj, gi].split((2, 2, 1, head.nc), 1)  # subset of predictions
+
+                # Box regression
+                pxy = pxy.sigmoid() * 2 - 0.5
+                pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i]
+                pbox = torch.cat((pxy, pwh), 1)  # predicted box
+                iou = bbox_iou(pbox, tbox[i], CIoU=True).squeeze()  # iou(prediction, target)
+                lbox += (1.0 - iou).mean()  # iou loss
+
+                # Objectness
+                iou = iou.detach().clamp(0).type(tobj.dtype)
+                if sort_obj_iou:
+                    j = iou.argsort()
+                    b, a, gj, gi, iou = b[j], a[j], gj[j], gi[j], iou[j]
+                if gr < 1:
+                    iou = (1.0 - gr) + gr * iou
+                tobj[b, a, gj, gi] = iou  # iou ratio
+
+                # Classification
+                if head.nc > 1:  # cls loss (only if multiple classes)
+                    t = torch.full_like(pcls, cn, device=self.device)  # targets
+                    t[range(n), tcls[i]] = cp
+                    lcls += BCEcls(pcls, t)  # BCE
+
+            obji = BCEobj(pi[..., 4], tobj)
+            lobj += obji * balance[i]  # obj loss
+            if autobalance:
+                balance[i] = balance[i] * 0.9999 + 0.0001 / obji.detach().item()
+
+        if autobalance:
+            balance = [x / balance[ssi] for x in balance]
+        lbox *= self.args.box
+        lobj *= self.args.obj
+        lcls *= self.args.cls
+
+        loss = lbox + lobj + lcls
+        return loss * bs, torch.cat((lbox, lobj, lcls)).detach()
+
+    # TODO: improve from API users perspective
+    def label_loss_items(self, loss_items=None, prefix="train"):
+        # We should just use named tensors here in future
+        keys = [f"{prefix}/lbox", f"{prefix}/lobj", f"{prefix}/lcls"]
+        return dict(zip(keys, loss_items)) if loss_items is not None else keys
+
+    def progress_string(self):
+        return ('\n' + '%11s' * 6) % \
+               ('Epoch', 'GPU_mem', 'box_loss', 'obj_loss', 'cls_loss', 'Size')
+
+    def plot_training_samples(self, batch, ni):
+        images = batch["img"]
+        cls = batch["cls"].squeeze(-1)
+        bboxes = batch["bboxes"]
+        paths = batch["im_file"]
+        batch_idx = batch["batch_idx"]
+        plot_images(images, batch_idx, cls, bboxes, paths=paths, fname=self.save_dir / f"train_batch{ni}.jpg")
+
+    def plot_metrics(self):
+        plot_results(file=self.csv)  # save results.png
+
+
+@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
+def train(cfg):
+    cfg.model = cfg.model or "models/yolov5n.yaml"
+    cfg.data = cfg.data or "coco128.yaml"  # or yolo.ClassificationDataset("mnist")
+    trainer = DetectionTrainer(cfg)
+    trainer.train()
+
+
+if __name__ == "__main__":
+    """
+    CLI usage:
+    python ultralytics/yolo/v8/segment/train.py cfg=yolov5n-seg.yaml data=coco128-segments epochs=100 img_size=640
+
+    TODO:
+    Direct cli support, i.e, yolov8 classify_train args.epochs 10
+    """
+    train()
diff --git a/ultralytics/yolo/v8/detect/val.py b/ultralytics/yolo/v8/detect/val.py
new file mode 100644
index 00000000..63e9e46f
--- /dev/null
+++ b/ultralytics/yolo/v8/detect/val.py
@@ -0,0 +1,218 @@
+import os
+
+import hydra
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from ultralytics.yolo.data import build_dataloader
+from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
+from ultralytics.yolo.engine.validator import BaseValidator
+from ultralytics.yolo.utils import ops
+from ultralytics.yolo.utils.checks import check_file, check_requirements
+from ultralytics.yolo.utils.files import yaml_load
+from ultralytics.yolo.utils.metrics import ConfusionMatrix, Metric, ap_per_class, box_iou, fitness_detection
+from ultralytics.yolo.utils.plotting import output_to_target, plot_images
+from ultralytics.yolo.utils.torch_utils import de_parallel
+
+
+class DetectionValidator(BaseValidator):
+
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
+        super().__init__(dataloader, save_dir, pbar, logger, args)
+        if self.args.save_json:
+            check_requirements(['pycocotools'])
+            self.process = ops.process_mask_upsample  # more accurate
+        else:
+            self.process = ops.process_mask  # faster
+        self.data_dict = yaml_load(check_file(self.args.data)) if self.args.data else None
+        self.is_coco = False
+        self.class_map = None
+        self.targets = None
+
+    def preprocess(self, batch):
+        batch["img"] = batch["img"].to(self.device, non_blocking=True)
+        batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
+        self.nb, _, self.height, self.width = batch["img"].shape  # batch size, channels, height, width
+        self.targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
+        self.targets = self.targets.to(self.device)
+        height, width = batch["img"].shape[2:]
+        self.targets[:, 2:] *= torch.tensor((width, height, width, height), device=self.device)  # to pixels
+        self.lb = [self.targets[self.targets[:, 0] == i, 1:]
+                   for i in range(self.nb)] if self.args.save_hybrid else []  # for autolabelling
+
+        return batch
+
+    def init_metrics(self, model):
+        if self.training:
+            head = de_parallel(model).model[-1]
+        else:
+            head = de_parallel(model).model.model[-1]
+
+        if self.data:
+            self.is_coco = isinstance(self.data.get('val'),
+                                      str) and self.data['val'].endswith(f'coco{os.sep}val2017.txt')
+            self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000))
+        self.nc = head.nc
+        self.names = model.names
+        if isinstance(self.names, (list, tuple)):  # old format
+            self.names = dict(enumerate(self.names))
+
+        self.iouv = torch.linspace(0.5, 0.95, 10, device=self.device)  # iou vector for mAP@0.5:0.95
+        self.niou = self.iouv.numel()
+        self.seen = 0
+        self.confusion_matrix = ConfusionMatrix(nc=self.nc)
+        self.metrics = Metric()
+        self.loss = torch.zeros(4, device=self.device)
+        self.jdict = []
+        self.stats = []
+
+    def get_desc(self):
+        return ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)")
+
+    def postprocess(self, preds):
+        preds = ops.non_max_suppression(preds,
+                                        self.args.conf_thres,
+                                        self.args.iou_thres,
+                                        labels=self.lb,
+                                        multi_label=True,
+                                        agnostic=self.args.single_cls,
+                                        max_det=self.args.max_det)
+        return preds
+
+    def update_metrics(self, preds, batch):
+        # Metrics
+        for si, (pred) in enumerate(preds):
+            labels = self.targets[self.targets[:, 0] == si, 1:]
+            nl, npr = labels.shape[0], pred.shape[0]  # number of labels, predictions
+            shape = batch["ori_shape"][si]
+            # path = batch["shape"][si][0]
+            correct_bboxes = torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device)  # init
+            self.seen += 1
+
+            if npr == 0:
+                if nl:
+                    self.stats.append((correct_bboxes, *torch.zeros((2, 0), device=self.device), labels[:, 0]))
+                    if self.args.plots:
+                        self.confusion_matrix.process_batch(detections=None, labels=labels[:, 0])
+                continue
+
+            # Predictions
+            if self.args.single_cls:
+                pred[:, 5] = 0
+            predn = pred.clone()
+            ops.scale_boxes(batch["img"][si].shape[1:], predn[:, :4], shape)  # native-space pred
+
+            # Evaluate
+            if nl:
+                tbox = ops.xywh2xyxy(labels[:, 1:5])  # target boxes
+                ops.scale_boxes(batch["img"][si].shape[1:], tbox, shape)  # native-space labels
+                labelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labels
+                correct_bboxes = self._process_batch(predn, labelsn, self.iouv)
+                # TODO: maybe remove these `self.` arguments as they already are member variable
+                if self.args.plots:
+                    self.confusion_matrix.process_batch(predn, labelsn)
+            self.stats.append((correct_bboxes, pred[:, 4], pred[:, 5], labels[:, 0]))  # (conf, pcls, tcls)
+
+            # TODO: Save/log
+            '''
+            if self.args.save_txt:
+                save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
+            if self.args.save_json:
+                pred_masks = scale_image(im[si].shape[1:],
+                                         pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), shape, shapes[si][1])
+                save_one_json(predn, jdict, path, class_map, pred_masks)  # append to COCO-JSON dictionary
+            # callbacks.run('on_val_image_end', pred, predn, path, names, im[si])
+            '''
+
+    def get_stats(self):
+        stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)]  # to numpy
+        if len(stats) and stats[0].any():
+            results = ap_per_class(*stats, plot=self.args.plots, save_dir=self.save_dir, names=self.names)
+            self.metrics.update(results[2:])
+        self.nt_per_class = np.bincount(stats[3].astype(int), minlength=self.nc)  # number of targets per class
+        metrics = {"fitness": fitness_detection(np.array(self.metrics.mean_results()).reshape(1, -1))}
+        metrics |= zip(self.metric_keys, self.metrics.mean_results())
+        return metrics
+
+    def print_results(self):
+        pf = '%22s' + '%11i' * 2 + '%11.3g' * 4  # print format
+        self.logger.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
+        if self.nt_per_class.sum() == 0:
+            self.logger.warning(
+                f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels')
+
+        # Print results per class
+        if (self.args.verbose or (self.nc < 50 and not self.training)) and self.nc > 1 and len(self.stats):
+            for i, c in enumerate(self.metrics.ap_class_index):
+                self.logger.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
+
+        if self.args.plots:
+            self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values()))
+
+    def _process_batch(self, detections, labels, iouv):
+        """
+        Return correct prediction matrix
+        Arguments:
+            detections (array[N, 6]), x1, y1, x2, y2, conf, class
+            labels (array[M, 5]), class, x1, y1, x2, y2
+        Returns:
+            correct (array[N, 10]), for 10 IoU levels
+        """
+        iou = box_iou(labels[:, 1:], detections[:, :4])
+        correct = np.zeros((detections.shape[0], iouv.shape[0])).astype(bool)
+        correct_class = labels[:, 0:1] == detections[:, 5]
+        for i in range(len(iouv)):
+            x = torch.where((iou >= iouv[i]) & correct_class)  # IoU > threshold and classes match
+            if x[0].shape[0]:
+                matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]),
+                                    1).cpu().numpy()  # [label, detect, iou]
+                if x[0].shape[0] > 1:
+                    matches = matches[matches[:, 2].argsort()[::-1]]
+                    matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
+                    # matches = matches[matches[:, 2].argsort()[::-1]]
+                    matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
+                correct[matches[:, 1].astype(int), i] = True
+        return torch.tensor(correct, dtype=torch.bool, device=iouv.device)
+
+    def get_dataloader(self, dataset_path, batch_size):
+        # TODO: manage splits differently
+        # calculate stride - check if model is initialized
+        gs = max(int(de_parallel(self.model).stride if self.model else 0), 32)
+        return build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, mode="val")[0]
+
+    # TODO: align with train loss metrics
+    @property
+    def metric_keys(self):
+        return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP_0.5(B)", "metrics/mAP_0.5:0.95(B)"]
+
+    def plot_val_samples(self, batch, ni):
+        images = batch["img"]
+        cls = batch["cls"].squeeze(-1)
+        bboxes = batch["bboxes"]
+        paths = batch["im_file"]
+        batch_idx = batch["batch_idx"]
+        plot_images(images,
+                    batch_idx,
+                    cls,
+                    bboxes,
+                    paths=paths,
+                    fname=self.save_dir / f"val_batch{ni}_labels.jpg",
+                    names=self.names)
+
+    def plot_predictions(self, batch, preds, ni):
+        images = batch["img"]
+        paths = batch["im_file"]
+        plot_images(images, *output_to_target(preds, max_det=15), paths, self.save_dir / f'val_batch{ni}_pred.jpg',
+                    self.names)  # pred
+
+
+@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
+def val(cfg):
+    cfg.data = cfg.data or "coco128.yaml"
+    validator = DetectionValidator(args=cfg)
+    validator(model=cfg.model)
+
+
+if __name__ == "__main__":
+    val()
diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py
index 5a4af698..a5481d26 100644
--- a/ultralytics/yolo/v8/segment/train.py
+++ b/ultralytics/yolo/v8/segment/train.py
@@ -250,7 +250,7 @@ class SegmentationTrainer(BaseTrainer):
                               cls,
                               bboxes,
                               masks,
-                              paths,
+                              paths=paths,
                               fname=self.save_dir / f"train_batch{ni}.jpg")
 
     def plot_metrics(self):
diff --git a/ultralytics/yolo/v8/segment/val.py b/ultralytics/yolo/v8/segment/val.py
index 18dead90..3784fd3d 100644
--- a/ultralytics/yolo/v8/segment/val.py
+++ b/ultralytics/yolo/v8/segment/val.py
@@ -252,7 +252,7 @@ class SegmentationValidator(BaseValidator):
         if len(self.plot_masks):
             plot_masks = torch.cat(self.plot_masks, dim=0)
         batch_idx, cls, bboxes, conf = output_to_target(preds[0], max_det=15)
-        plot_images_and_masks(images, batch_idx, cls, bboxes, plot_masks, paths, conf,
+        plot_images_and_masks(images, batch_idx, cls, bboxes, plot_masks, conf, paths,
                               self.save_dir / f'val_batch{ni}_pred.jpg', self.names)  # pred
         self.plot_masks.clear()