From 527a97759b5ebee35be3a64a0a6b7b52356731db Mon Sep 17 00:00:00 2001
From: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Date: Sun, 28 May 2023 19:45:41 +0530
Subject: [PATCH] Revert loss head PR (#2873)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
---
 ultralytics/nn/tasks.py               |  31 ---
 ultralytics/yolo/engine/trainer.py    |   9 +-
 ultralytics/yolo/engine/validator.py  |   3 +-
 ultralytics/yolo/utils/loss.py        | 293 --------------------------
 ultralytics/yolo/v8/classify/train.py |   7 +
 ultralytics/yolo/v8/detect/train.py   | 107 ++++++++++
 ultralytics/yolo/v8/pose/train.py     | 104 +++++++++
 ultralytics/yolo/v8/segment/train.py  | 108 ++++++++++
 8 files changed, 335 insertions(+), 327 deletions(-)

diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index 28fa0336..05caf862 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -13,7 +13,6 @@ from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottlenec
                                     RTDETRDecoder, Segment)
 from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
 from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml
-from ultralytics.yolo.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
 from ultralytics.yolo.utils.plotting import feature_visualization
 from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights,
                                                 intersect_dicts, make_divisible, model_info, scale_img, time_sync)
@@ -176,23 +175,6 @@ class BaseModel(nn.Module):
         if verbose:
             LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights')
 
-    def loss(self, batch, preds=None):
-        """
-        Compute loss
-
-        Args:
-            batch (dict): Batch to compute loss on
-            pred (torch.Tensor | List[torch.Tensor]): Predictions.
-        """
-        if not hasattr(self, 'criterion'):
-            self.criterion = self.init_criterion()
-
-        preds = self.forward(batch['img']) if preds is None else preds
-        return self.criterion(preds, batch)
-
-    def init_criterion(self):
-        raise NotImplementedError('compute_loss() needs to be implemented by task heads')
-
 
 class DetectionModel(BaseModel):
     """YOLOv8 detection model."""
@@ -269,9 +251,6 @@ class DetectionModel(BaseModel):
         y[-1] = y[-1][..., i:]  # small
         return y
 
-    def init_criterion(self):
-        return v8DetectionLoss(self)
-
 
 class SegmentationModel(DetectionModel):
     """YOLOv8 segmentation model."""
@@ -284,9 +263,6 @@ class SegmentationModel(DetectionModel):
         """Undocumented function."""
         raise NotImplementedError(emojis('WARNING ⚠️ SegmentationModel has not supported augment inference yet!'))
 
-    def init_criterion(self):
-        return v8SegmentationLoss(self)
-
 
 class PoseModel(DetectionModel):
     """YOLOv8 pose model."""
@@ -300,9 +276,6 @@ class PoseModel(DetectionModel):
             cfg['kpt_shape'] = data_kpt_shape
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 
-    def init_criterion(self):
-        return v8PoseLoss(self)
-
 
 class ClassificationModel(BaseModel):
     """YOLOv8 classification model."""
@@ -370,10 +343,6 @@ class ClassificationModel(BaseModel):
                 if m[i].out_channels != nc:
                     m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
 
-    def init_criterion(self):
-        """Compute the classification loss between predictions and true labels."""
-        return v8ClassificationLoss()
-
 
 class Ensemble(nn.ModuleList):
     """Ensemble of models."""
diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py
index c7754109..f220ad7f 100644
--- a/ultralytics/yolo/engine/trainer.py
+++ b/ultralytics/yolo/engine/trainer.py
@@ -325,7 +325,8 @@ class BaseTrainer:
                 # Forward
                 with torch.cuda.amp.autocast(self.amp):
                     batch = self.preprocess_batch(batch)
-                    self.loss, self.loss_items = de_parallel(self.model).loss(batch)
+                    preds = self.model(batch['img'])
+                    self.loss, self.loss_items = self.criterion(preds, batch)
                     if RANK != -1:
                         self.loss *= world_size
                     self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
@@ -495,6 +496,12 @@ class BaseTrainer:
         """Build dataset"""
         raise NotImplementedError('build_dataset function not implemented in trainer')
 
+    def criterion(self, preds, batch):
+        """
+        Returns loss and individual loss items as Tensor.
+        """
+        raise NotImplementedError('criterion function not implemented in trainer')
+
     def label_loss_items(self, loss_items=None, prefix='train'):
         """
         Returns a loss dict with labelled training loss items tensor
diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py
index 09f297bd..ad107a2c 100644
--- a/ultralytics/yolo/engine/validator.py
+++ b/ultralytics/yolo/engine/validator.py
@@ -162,8 +162,7 @@ class BaseValidator:
             # Loss
             with dt[2]:
                 if self.training:
-                    loss_items = model.loss(batch, preds)
-                    self.loss += loss_items[1]
+                    self.loss += trainer.criterion(preds, batch)[1]
 
             # Postprocess
             with dt[3]:
diff --git a/ultralytics/yolo/utils/loss.py b/ultralytics/yolo/utils/loss.py
index 5266982b..73aba68f 100644
--- a/ultralytics/yolo/utils/loss.py
+++ b/ultralytics/yolo/utils/loss.py
@@ -4,10 +4,6 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from ultralytics.yolo.utils.metrics import OKS_SIGMA
-from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
-from ultralytics.yolo.utils.tal import TaskAlignedAssigner, dist2bbox, make_anchors
-
 from .metrics import bbox_iou
 from .tal import bbox2dist
 
@@ -77,292 +73,3 @@ class KeypointLoss(nn.Module):
         # e = d / (2 * (area * self.sigmas) ** 2 + 1e-9)  # from formula
         e = d / (2 * self.sigmas) ** 2 / (area + 1e-9) / 2  # from cocoeval
         return kpt_loss_factor * ((1 - torch.exp(-e)) * kpt_mask).mean()
-
-
-# Criterion class for computing Detection training losses
-class v8DetectionLoss:
-
-    def __init__(self, model):  # model must be de-paralleled
-
-        device = next(model.parameters()).device  # get model device
-        h = model.args  # hyperparameters
-
-        m = model.model[-1]  # Detect() module
-        self.bce = nn.BCEWithLogitsLoss(reduction='none')
-        self.hyp = h
-        self.stride = m.stride  # model strides
-        self.nc = m.nc  # number of classes
-        self.no = m.no
-        self.reg_max = m.reg_max
-        self.device = device
-
-        self.use_dfl = m.reg_max > 1
-
-        self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
-        self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device)
-        self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
-
-    def preprocess(self, targets, batch_size, scale_tensor):
-        """Preprocesses the target counts and matches with the input batch size to output a tensor."""
-        if targets.shape[0] == 0:
-            out = torch.zeros(batch_size, 0, 5, device=self.device)
-        else:
-            i = targets[:, 0]  # image index
-            _, counts = i.unique(return_counts=True)
-            counts = counts.to(dtype=torch.int32)
-            out = torch.zeros(batch_size, counts.max(), 5, device=self.device)
-            for j in range(batch_size):
-                matches = i == j
-                n = matches.sum()
-                if n:
-                    out[j, :n] = targets[matches, 1:]
-            out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
-        return out
-
-    def bbox_decode(self, anchor_points, pred_dist):
-        """Decode predicted object bounding box coordinates from anchor points and distribution."""
-        if self.use_dfl:
-            b, a, c = pred_dist.shape  # batch, anchors, channels
-            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
-            # pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))
-            # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
-        return dist2bbox(pred_dist, anchor_points, xywh=False)
-
-    def __call__(self, preds, batch):
-        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
-        loss = torch.zeros(3, device=self.device)  # box, cls, dfl
-        feats = preds[1] if isinstance(preds, tuple) else preds
-        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
-            (self.reg_max * 4, self.nc), 1)
-
-        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
-        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
-
-        dtype = pred_scores.dtype
-        batch_size = pred_scores.shape[0]
-        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
-        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
-
-        # targets
-        targets = torch.cat((batch['batch_idx'].view(-1, 1), batch['cls'].view(-1, 1), batch['bboxes']), 1)
-        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
-        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
-        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
-
-        # pboxes
-        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
-
-        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
-            pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
-            anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
-
-        target_scores_sum = max(target_scores.sum(), 1)
-
-        # cls loss
-        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
-        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
-
-        # bbox loss
-        if fg_mask.sum():
-            target_bboxes /= stride_tensor
-            loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
-                                              target_scores_sum, fg_mask)
-
-        loss[0] *= self.hyp.box  # box gain
-        loss[1] *= self.hyp.cls  # cls gain
-        loss[2] *= self.hyp.dfl  # dfl gain
-
-        return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)
-
-
-# Criterion class for computing training losses
-class v8SegmentationLoss(v8DetectionLoss):
-
-    def __init__(self, model, overlap=True):  # model must be de-paralleled
-        super().__init__(model)
-        self.nm = model.model[-1].nm  # number of masks
-        self.overlap = overlap
-
-    def __call__(self, preds, batch):
-        """Calculate and return the loss for the YOLO model."""
-        loss = torch.zeros(4, device=self.device)  # box, cls, dfl
-        feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
-        batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
-        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
-            (self.reg_max * 4, self.nc), 1)
-
-        # b, grids, ..
-        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
-        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
-        pred_masks = pred_masks.permute(0, 2, 1).contiguous()
-
-        dtype = pred_scores.dtype
-        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
-        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
-
-        # targets
-        try:
-            batch_idx = batch['batch_idx'].view(-1, 1)
-            targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
-            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
-            gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
-            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
-        except RuntimeError as e:
-            raise TypeError('ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n'
-                            "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
-                            "i.e. 'yolo train model=yolov8n-seg.pt data=coco128.yaml'.\nVerify your dataset is a "
-                            "correctly formatted 'segment' dataset using 'data=coco128-seg.yaml' "
-                            'as an example.\nSee https://docs.ultralytics.com/tasks/segment/ for help.') from e
-
-        # pboxes
-        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
-
-        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
-            pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
-            anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
-
-        target_scores_sum = max(target_scores.sum(), 1)
-
-        # cls loss
-        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
-        loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
-
-        if fg_mask.sum():
-            # bbox loss
-            loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor,
-                                              target_scores, target_scores_sum, fg_mask)
-            # masks loss
-            masks = batch['masks'].to(self.device).float()
-            if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
-                masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
-
-            for i in range(batch_size):
-                if fg_mask[i].sum():
-                    mask_idx = target_gt_idx[i][fg_mask[i]]
-                    if self.overlap:
-                        gt_mask = torch.where(masks[[i]] == (mask_idx + 1).view(-1, 1, 1), 1.0, 0.0)
-                    else:
-                        gt_mask = masks[batch_idx.view(-1) == i][mask_idx]
-                    xyxyn = target_bboxes[i][fg_mask[i]] / imgsz[[1, 0, 1, 0]]
-                    marea = xyxy2xywh(xyxyn)[:, 2:].prod(1)
-                    mxyxy = xyxyn * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device)
-                    loss[1] += self.single_mask_loss(gt_mask, pred_masks[i][fg_mask[i]], proto[i], mxyxy, marea)  # seg
-
-                # WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
-                else:
-                    loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss
-
-        # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
-        else:
-            loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss
-
-        loss[0] *= self.hyp.box  # box gain
-        loss[1] *= self.hyp.box / batch_size  # seg gain
-        loss[2] *= self.hyp.cls  # cls gain
-        loss[3] *= self.hyp.dfl  # dfl gain
-
-        return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)
-
-    def single_mask_loss(self, gt_mask, pred, proto, xyxy, area):
-        """Mask loss for one image."""
-        pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:])  # (n, 32) @ (32,80,80) -> (n,80,80)
-        loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none')
-        return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
-
-
-# Criterion class for computing training losses
-class v8PoseLoss(v8DetectionLoss):
-
-    def __init__(self, model):  # model must be de-paralleled
-        super().__init__(model)
-        self.kpt_shape = model.model[-1].kpt_shape
-        self.bce_pose = nn.BCEWithLogitsLoss()
-        is_pose = self.kpt_shape == [17, 3]
-        nkpt = self.kpt_shape[0]  # number of keypoints
-        sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
-        self.keypoint_loss = KeypointLoss(sigmas=sigmas)
-
-    def __call__(self, preds, batch):
-        """Calculate the total loss and detach it."""
-        loss = torch.zeros(5, device=self.device)  # box, cls, dfl, kpt_location, kpt_visibility
-        feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
-        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
-            (self.reg_max * 4, self.nc), 1)
-
-        # b, grids, ..
-        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
-        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
-        pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()
-
-        dtype = pred_scores.dtype
-        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
-        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
-
-        # targets
-        batch_size = pred_scores.shape[0]
-        batch_idx = batch['batch_idx'].view(-1, 1)
-        targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
-        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
-        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
-        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
-
-        # pboxes
-        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
-        pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape))  # (b, h*w, 17, 3)
-
-        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
-            pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
-            anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
-
-        target_scores_sum = max(target_scores.sum(), 1)
-
-        # cls loss
-        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
-        loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
-
-        # bbox loss
-        if fg_mask.sum():
-            target_bboxes /= stride_tensor
-            loss[0], loss[4] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
-                                              target_scores_sum, fg_mask)
-            keypoints = batch['keypoints'].to(self.device).float().clone()
-            keypoints[..., 0] *= imgsz[1]
-            keypoints[..., 1] *= imgsz[0]
-            for i in range(batch_size):
-                if fg_mask[i].sum():
-                    idx = target_gt_idx[i][fg_mask[i]]
-                    gt_kpt = keypoints[batch_idx.view(-1) == i][idx]  # (n, 51)
-                    gt_kpt[..., 0] /= stride_tensor[fg_mask[i]]
-                    gt_kpt[..., 1] /= stride_tensor[fg_mask[i]]
-                    area = xyxy2xywh(target_bboxes[i][fg_mask[i]])[:, 2:].prod(1, keepdim=True)
-                    pred_kpt = pred_kpts[i][fg_mask[i]]
-                    kpt_mask = gt_kpt[..., 2] != 0
-                    loss[1] += self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area)  # pose loss
-                    # kpt_score loss
-                    if pred_kpt.shape[-1] == 3:
-                        loss[2] += self.bce_pose(pred_kpt[..., 2], kpt_mask.float())  # keypoint obj loss
-
-        loss[0] *= self.hyp.box  # box gain
-        loss[1] *= self.hyp.pose / batch_size  # pose gain
-        loss[2] *= self.hyp.kobj / batch_size  # kobj gain
-        loss[3] *= self.hyp.cls  # cls gain
-        loss[4] *= self.hyp.dfl  # dfl gain
-
-        return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)
-
-    def kpts_decode(self, anchor_points, pred_kpts):
-        """Decodes predicted keypoints to image coordinates."""
-        y = pred_kpts.clone()
-        y[..., :2] *= 2.0
-        y[..., 0] += anchor_points[:, [0]] - 0.5
-        y[..., 1] += anchor_points[:, [1]] - 0.5
-        return y
-
-
-class v8ClassificationLoss:
-
-    def __call__(self, preds, batch):
-        """Compute the classification loss between predictions and true labels."""
-        loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / 64  # TODO: remove hardcoding
-        loss_items = loss.detach()
-        return loss, loss_items
diff --git a/ultralytics/yolo/v8/classify/train.py b/ultralytics/yolo/v8/classify/train.py
index 29496446..6c8b6574 100644
--- a/ultralytics/yolo/v8/classify/train.py
+++ b/ultralytics/yolo/v8/classify/train.py
@@ -41,6 +41,7 @@ class ClassificationTrainer(BaseTrainer):
                 m.p = self.args.dropout  # set dropout
         for p in model.parameters():
             p.requires_grad = True  # for training
+
         return model
 
     def setup_model(self):
@@ -102,6 +103,12 @@ class ClassificationTrainer(BaseTrainer):
         self.loss_names = ['loss']
         return v8.classify.ClassificationValidator(self.test_loader, self.save_dir)
 
+    def criterion(self, preds, batch):
+        """Compute the classification loss between predictions and true labels."""
+        loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / self.args.nbs
+        loss_items = loss.detach()
+        return loss, loss_items
+
     def label_loss_items(self, loss_items=None, prefix='train'):
         """
         Returns a loss dict with labelled training loss items tensor
diff --git a/ultralytics/yolo/v8/detect/train.py b/ultralytics/yolo/v8/detect/train.py
index 1b475ed0..7f0fc1df 100644
--- a/ultralytics/yolo/v8/detect/train.py
+++ b/ultralytics/yolo/v8/detect/train.py
@@ -2,6 +2,8 @@
 from copy import copy
 
 import numpy as np
+import torch
+import torch.nn as nn
 
 from ultralytics.nn.tasks import DetectionModel
 from ultralytics.yolo import v8
@@ -9,7 +11,10 @@ from ultralytics.yolo.data import build_dataloader, build_yolo_dataset
 from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
 from ultralytics.yolo.engine.trainer import BaseTrainer
 from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
+from ultralytics.yolo.utils.loss import BboxLoss
+from ultralytics.yolo.utils.ops import xywh2xyxy
 from ultralytics.yolo.utils.plotting import plot_images, plot_labels, plot_results
+from ultralytics.yolo.utils.tal import TaskAlignedAssigner, dist2bbox, make_anchors
 from ultralytics.yolo.utils.torch_utils import de_parallel, torch_distributed_zero_first
 
 
@@ -86,6 +91,12 @@ class DetectionTrainer(BaseTrainer):
         self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
         return v8.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
 
+    def criterion(self, preds, batch):
+        """Compute loss for YOLO prediction and ground-truth."""
+        if not hasattr(self, 'compute_loss'):
+            self.compute_loss = Loss(de_parallel(self.model))
+        return self.compute_loss(preds, batch)
+
     def label_loss_items(self, loss_items=None, prefix='train'):
         """
         Returns a loss dict with labelled training loss items tensor
@@ -124,6 +135,102 @@ class DetectionTrainer(BaseTrainer):
         plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir, on_plot=self.on_plot)
 
 
+# Criterion class for computing training losses
+class Loss:
+
+    def __init__(self, model):  # model must be de-paralleled
+
+        device = next(model.parameters()).device  # get model device
+        h = model.args  # hyperparameters
+
+        m = model.model[-1]  # Detect() module
+        self.bce = nn.BCEWithLogitsLoss(reduction='none')
+        self.hyp = h
+        self.stride = m.stride  # model strides
+        self.nc = m.nc  # number of classes
+        self.no = m.no
+        self.reg_max = m.reg_max
+        self.device = device
+
+        self.use_dfl = m.reg_max > 1
+
+        self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
+        self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device)
+        self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
+
+    def preprocess(self, targets, batch_size, scale_tensor):
+        """Preprocesses the target counts and matches with the input batch size to output a tensor."""
+        if targets.shape[0] == 0:
+            out = torch.zeros(batch_size, 0, 5, device=self.device)
+        else:
+            i = targets[:, 0]  # image index
+            _, counts = i.unique(return_counts=True)
+            counts = counts.to(dtype=torch.int32)
+            out = torch.zeros(batch_size, counts.max(), 5, device=self.device)
+            for j in range(batch_size):
+                matches = i == j
+                n = matches.sum()
+                if n:
+                    out[j, :n] = targets[matches, 1:]
+            out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
+        return out
+
+    def bbox_decode(self, anchor_points, pred_dist):
+        """Decode predicted object bounding box coordinates from anchor points and distribution."""
+        if self.use_dfl:
+            b, a, c = pred_dist.shape  # batch, anchors, channels
+            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
+            # pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))
+            # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
+        return dist2bbox(pred_dist, anchor_points, xywh=False)
+
+    def __call__(self, preds, batch):
+        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
+        loss = torch.zeros(3, device=self.device)  # box, cls, dfl
+        feats = preds[1] if isinstance(preds, tuple) else preds
+        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
+            (self.reg_max * 4, self.nc), 1)
+
+        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
+        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+
+        dtype = pred_scores.dtype
+        batch_size = pred_scores.shape[0]
+        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
+        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+
+        # targets
+        targets = torch.cat((batch['batch_idx'].view(-1, 1), batch['cls'].view(-1, 1), batch['bboxes']), 1)
+        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
+        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
+
+        # pboxes
+        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
+
+        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
+            pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
+            anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
+
+        target_scores_sum = max(target_scores.sum(), 1)
+
+        # cls loss
+        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
+        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
+
+        # bbox loss
+        if fg_mask.sum():
+            target_bboxes /= stride_tensor
+            loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
+                                              target_scores_sum, fg_mask)
+
+        loss[0] *= self.hyp.box  # box gain
+        loss[1] *= self.hyp.cls  # cls gain
+        loss[2] *= self.hyp.dfl  # dfl gain
+
+        return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)
+
+
 def train(cfg=DEFAULT_CFG, use_python=False):
     """Train and optimize YOLO model given training data and device."""
     model = cfg.model or 'yolov8n.pt'
diff --git a/ultralytics/yolo/v8/pose/train.py b/ultralytics/yolo/v8/pose/train.py
index af3043c1..b7890f89 100644
--- a/ultralytics/yolo/v8/pose/train.py
+++ b/ultralytics/yolo/v8/pose/train.py
@@ -2,10 +2,19 @@
 
 from copy import copy
 
+import torch
+import torch.nn as nn
+
 from ultralytics.nn.tasks import PoseModel
 from ultralytics.yolo import v8
 from ultralytics.yolo.utils import DEFAULT_CFG
+from ultralytics.yolo.utils.loss import KeypointLoss
+from ultralytics.yolo.utils.metrics import OKS_SIGMA
+from ultralytics.yolo.utils.ops import xyxy2xywh
 from ultralytics.yolo.utils.plotting import plot_images, plot_results
+from ultralytics.yolo.utils.tal import make_anchors
+from ultralytics.yolo.utils.torch_utils import de_parallel
+from ultralytics.yolo.v8.detect.train import Loss
 
 
 # BaseTrainer python usage
@@ -36,6 +45,12 @@ class PoseTrainer(v8.detect.DetectionTrainer):
         self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss'
         return v8.pose.PoseValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
 
+    def criterion(self, preds, batch):
+        """Computes pose loss for the YOLO model."""
+        if not hasattr(self, 'compute_loss'):
+            self.compute_loss = PoseLoss(de_parallel(self.model))
+        return self.compute_loss(preds, batch)
+
     def plot_training_samples(self, batch, ni):
         """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
         images = batch['img']
@@ -58,6 +73,95 @@ class PoseTrainer(v8.detect.DetectionTrainer):
         plot_results(file=self.csv, pose=True, on_plot=self.on_plot)  # save results.png
 
 
+# Criterion class for computing training losses
+class PoseLoss(Loss):
+
+    def __init__(self, model):  # model must be de-paralleled
+        super().__init__(model)
+        self.kpt_shape = model.model[-1].kpt_shape
+        self.bce_pose = nn.BCEWithLogitsLoss()
+        is_pose = self.kpt_shape == [17, 3]
+        nkpt = self.kpt_shape[0]  # number of keypoints
+        sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
+        self.keypoint_loss = KeypointLoss(sigmas=sigmas)
+
+    def __call__(self, preds, batch):
+        """Calculate the total loss and detach it."""
+        loss = torch.zeros(5, device=self.device)  # box, cls, dfl, kpt_location, kpt_visibility
+        feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
+        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
+            (self.reg_max * 4, self.nc), 1)
+
+        # b, grids, ..
+        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
+        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+        pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()
+
+        dtype = pred_scores.dtype
+        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
+        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+
+        # targets
+        batch_size = pred_scores.shape[0]
+        batch_idx = batch['batch_idx'].view(-1, 1)
+        targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
+        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
+        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
+
+        # pboxes
+        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
+        pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape))  # (b, h*w, 17, 3)
+
+        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
+            pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
+            anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
+
+        target_scores_sum = max(target_scores.sum(), 1)
+
+        # cls loss
+        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
+        loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
+
+        # bbox loss
+        if fg_mask.sum():
+            target_bboxes /= stride_tensor
+            loss[0], loss[4] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
+                                              target_scores_sum, fg_mask)
+            keypoints = batch['keypoints'].to(self.device).float().clone()
+            keypoints[..., 0] *= imgsz[1]
+            keypoints[..., 1] *= imgsz[0]
+            for i in range(batch_size):
+                if fg_mask[i].sum():
+                    idx = target_gt_idx[i][fg_mask[i]]
+                    gt_kpt = keypoints[batch_idx.view(-1) == i][idx]  # (n, 51)
+                    gt_kpt[..., 0] /= stride_tensor[fg_mask[i]]
+                    gt_kpt[..., 1] /= stride_tensor[fg_mask[i]]
+                    area = xyxy2xywh(target_bboxes[i][fg_mask[i]])[:, 2:].prod(1, keepdim=True)
+                    pred_kpt = pred_kpts[i][fg_mask[i]]
+                    kpt_mask = gt_kpt[..., 2] != 0
+                    loss[1] += self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area)  # pose loss
+                    # kpt_score loss
+                    if pred_kpt.shape[-1] == 3:
+                        loss[2] += self.bce_pose(pred_kpt[..., 2], kpt_mask.float())  # keypoint obj loss
+
+        loss[0] *= self.hyp.box  # box gain
+        loss[1] *= self.hyp.pose / batch_size  # pose gain
+        loss[2] *= self.hyp.kobj / batch_size  # kobj gain
+        loss[3] *= self.hyp.cls  # cls gain
+        loss[4] *= self.hyp.dfl  # dfl gain
+
+        return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)
+
+    def kpts_decode(self, anchor_points, pred_kpts):
+        """Decodes predicted keypoints to image coordinates."""
+        y = pred_kpts.clone()
+        y[..., :2] *= 2.0
+        y[..., 0] += anchor_points[:, [0]] - 0.5
+        y[..., 1] += anchor_points[:, [1]] - 0.5
+        return y
+
+
 def train(cfg=DEFAULT_CFG, use_python=False):
     """Train the YOLO model on the given data and device."""
     model = cfg.model or 'yolov8n-pose.yaml'
diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py
index ab66cf06..32ce510e 100644
--- a/ultralytics/yolo/v8/segment/train.py
+++ b/ultralytics/yolo/v8/segment/train.py
@@ -1,10 +1,17 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 from copy import copy
 
+import torch
+import torch.nn.functional as F
+
 from ultralytics.nn.tasks import SegmentationModel
 from ultralytics.yolo import v8
 from ultralytics.yolo.utils import DEFAULT_CFG, RANK
+from ultralytics.yolo.utils.ops import crop_mask, xyxy2xywh
 from ultralytics.yolo.utils.plotting import plot_images, plot_results
+from ultralytics.yolo.utils.tal import make_anchors
+from ultralytics.yolo.utils.torch_utils import de_parallel
+from ultralytics.yolo.v8.detect.train import Loss
 
 
 # BaseTrainer python usage
@@ -30,6 +37,12 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):
         self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss'
         return v8.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
 
+    def criterion(self, preds, batch):
+        """Returns the computed loss using the SegLoss class on the given predictions and batch."""
+        if not hasattr(self, 'compute_loss'):
+            self.compute_loss = SegLoss(de_parallel(self.model), overlap=self.args.overlap_mask)
+        return self.compute_loss(preds, batch)
+
     def plot_training_samples(self, batch, ni):
         """Creates a plot of training sample images with labels and box coordinates."""
         plot_images(batch['img'],
@@ -46,6 +59,101 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):
         plot_results(file=self.csv, segment=True, on_plot=self.on_plot)  # save results.png
 
 
+# Criterion class for computing training losses
+class SegLoss(Loss):
+
+    def __init__(self, model, overlap=True):  # model must be de-paralleled
+        super().__init__(model)
+        self.nm = model.model[-1].nm  # number of masks
+        self.overlap = overlap
+
+    def __call__(self, preds, batch):
+        """Calculate and return the loss for the YOLO model."""
+        loss = torch.zeros(4, device=self.device)  # box, cls, dfl
+        feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
+        batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
+        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
+            (self.reg_max * 4, self.nc), 1)
+
+        # b, grids, ..
+        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
+        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
+        pred_masks = pred_masks.permute(0, 2, 1).contiguous()
+
+        dtype = pred_scores.dtype
+        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
+        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
+
+        # targets
+        try:
+            batch_idx = batch['batch_idx'].view(-1, 1)
+            targets = torch.cat((batch_idx, batch['cls'].view(-1, 1), batch['bboxes']), 1)
+            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
+            gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
+            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
+        except RuntimeError as e:
+            raise TypeError('ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n'
+                            "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
+                            "i.e. 'yolo train model=yolov8n-seg.pt data=coco128.yaml'.\nVerify your dataset is a "
+                            "correctly formatted 'segment' dataset using 'data=coco128-seg.yaml' "
+                            'as an example.\nSee https://docs.ultralytics.com/tasks/segment/ for help.') from e
+
+        # pboxes
+        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
+
+        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
+            pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
+            anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
+
+        target_scores_sum = max(target_scores.sum(), 1)
+
+        # cls loss
+        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
+        loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
+
+        if fg_mask.sum():
+            # bbox loss
+            loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor,
+                                              target_scores, target_scores_sum, fg_mask)
+            # masks loss
+            masks = batch['masks'].to(self.device).float()
+            if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
+                masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
+
+            for i in range(batch_size):
+                if fg_mask[i].sum():
+                    mask_idx = target_gt_idx[i][fg_mask[i]]
+                    if self.overlap:
+                        gt_mask = torch.where(masks[[i]] == (mask_idx + 1).view(-1, 1, 1), 1.0, 0.0)
+                    else:
+                        gt_mask = masks[batch_idx.view(-1) == i][mask_idx]
+                    xyxyn = target_bboxes[i][fg_mask[i]] / imgsz[[1, 0, 1, 0]]
+                    marea = xyxy2xywh(xyxyn)[:, 2:].prod(1)
+                    mxyxy = xyxyn * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device)
+                    loss[1] += self.single_mask_loss(gt_mask, pred_masks[i][fg_mask[i]], proto[i], mxyxy, marea)  # seg
+
+                # WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
+                else:
+                    loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss
+
+        # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
+        else:
+            loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss
+
+        loss[0] *= self.hyp.box  # box gain
+        loss[1] *= self.hyp.box / batch_size  # seg gain
+        loss[2] *= self.hyp.cls  # cls gain
+        loss[3] *= self.hyp.dfl  # dfl gain
+
+        return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)
+
+    def single_mask_loss(self, gt_mask, pred, proto, xyxy, area):
+        """Mask loss for one image."""
+        pred_mask = (pred @ proto.view(self.nm, -1)).view(-1, *proto.shape[1:])  # (n, 32) @ (32,80,80) -> (n,80,80)
+        loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction='none')
+        return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).mean()
+
+
 def train(cfg=DEFAULT_CFG, use_python=False):
     """Train a YOLO segmentation model based on passed arguments."""
     model = cfg.model or 'yolov8n-seg.pt'