diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index 4b5364bc..7ed500bd 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -4,6 +4,8 @@
 name: Publish to PyPI and Deploy Docs
 
 on:
+  push:
+    branches: [main]
   workflow_dispatch:
     inputs:
       pypi:
@@ -12,8 +14,6 @@ on:
       docs:
         type: boolean
         description: Deploy Docs
-  push:
-    branches: [main]
 
 jobs:
   publish:
@@ -35,9 +35,9 @@ jobs:
       - name: Check PyPI version
         shell: python
         run: |
+          import os
           import pkg_resources as pkg
           import ultralytics
-          import os
           from ultralytics.yolo.utils.checks import check_latest_pypi_version
 
           v_local = pkg.parse_version(ultralytics.__version__).release
@@ -45,7 +45,7 @@ jobs:
           print(f'Local version is {v_local}')
           print(f'PyPI version is {v_pypi}')
           d = [a - b for a, b in zip(v_local, v_pypi)]  # diff
-          increment = (d[0] == d[1] == 0) and d[2] == 1  # only patch increment by 1
+          increment = (d[0] == d[1] == 0) and d[2] == 1  # only publish if patch version increments by 1
           os.system(f'echo "increment={increment}" >> $GITHUB_OUTPUT')
           if increment:
               print('Local version is higher than PyPI version. Publishing new version to PyPI ✅.')
@@ -64,4 +64,4 @@ jobs:
         run: |
           mkdocs gh-deploy || true
           git checkout gh-pages
-          git push https://github.com/ultralytics/docs gh-pages --force
+          git push https://${{ secrets.PERSONAL_ACCESS_TOKEN }}@github.com/ultralytics/docs gh-pages --force
diff --git a/docs/cfg.md b/docs/cfg.md
index 4ec5a5b6..d7e2c60d 100644
--- a/docs/cfg.md
+++ b/docs/cfg.md
@@ -110,7 +110,6 @@ task.
 | mask_ratio      | 4      | mask downsample ratio (segment train only)                                     |
 | dropout         | 0.0    | use dropout regularization (classify train only)                               |
 | val             | True   | validate/test during training                                                  |
-| min_memory      | False  | minimize memory footprint loss function, choices=[False, True, <roll_out_thr>] |
 
 ### Prediction
 
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index 2e78eaab..42bd9478 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license
 
-__version__ = '8.0.51'
+__version__ = '8.0.52'
 
 from ultralytics.yolo.engine.model import YOLO
 from ultralytics.yolo.utils.checks import check_yolo as checks
diff --git a/ultralytics/hub/__init__.py b/ultralytics/hub/__init__.py
index 357a4206..bd2ae2c7 100644
--- a/ultralytics/hub/__init__.py
+++ b/ultralytics/hub/__init__.py
@@ -5,13 +5,9 @@ import requests
 from ultralytics.hub.auth import Auth
 from ultralytics.hub.session import HUBTrainingSession
 from ultralytics.hub.utils import PREFIX, split_key
-from ultralytics.yolo.engine.exporter import EXPORT_FORMATS_LIST
 from ultralytics.yolo.engine.model import YOLO
 from ultralytics.yolo.utils import LOGGER, emojis
 
-# Define all export formats
-EXPORT_FORMATS_HUB = EXPORT_FORMATS_LIST + ['ultralytics_tflite', 'ultralytics_coreml']
-
 
 def start(key=''):
     """
@@ -63,9 +59,15 @@ def reset_model(key=''):
     LOGGER.warning(f'{PREFIX}Model reset failure {r.status_code} {r.reason}')
 
 
+def export_fmts_hub():
+    # Returns a list of HUB-supported export formats
+    from ultralytics.yolo.engine.exporter import export_formats
+    return list(export_formats()['Argument'][1:]) + ['ultralytics_tflite', 'ultralytics_coreml']
+
+
 def export_model(key='', format='torchscript'):
     # Export a model to all formats
-    assert format in EXPORT_FORMATS_HUB, f"Unsupported export format '{format}', valid formats are {EXPORT_FORMATS_HUB}"
+    assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
     api_key, model_id = split_key(key)
     r = requests.post('https://api.ultralytics.com/export',
                       json={
@@ -78,7 +80,7 @@ def export_model(key='', format='torchscript'):
 
 def get_export(key='', format='torchscript'):
     # Get an exported model dictionary with download URL
-    assert format in EXPORT_FORMATS_HUB, f"Unsupported export format '{format}', valid formats are {EXPORT_FORMATS_HUB}"
+    assert format in export_fmts_hub, f"Unsupported export format '{format}', valid formats are {export_fmts_hub}"
     api_key, model_id = split_key(key)
     r = requests.post('https://api.ultralytics.com/get-export',
                       json={
diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py
index bf80c1b3..a0b79d4f 100644
--- a/ultralytics/nn/autobackend.py
+++ b/ultralytics/nn/autobackend.py
@@ -184,7 +184,7 @@ class AutoBackend(nn.Module):
             LOGGER.info(f'Loading {w} for CoreML inference...')
             import coremltools as ct
             model = ct.models.MLModel(w)
-            metadata = model.user_defined_metadata
+            metadata = dict(model.user_defined_metadata)
         elif saved_model:  # TF SavedModel
             LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
             import tensorflow as tf
@@ -256,10 +256,10 @@ class AutoBackend(nn.Module):
             nhwc = model.runtime.startswith("tensorflow")
             '''
         else:
-            from ultralytics.yolo.engine.exporter import EXPORT_FORMATS_TABLE
+            from ultralytics.yolo.engine.exporter import export_formats
             raise TypeError(f"model='{w}' is not a supported model format. "
                             'See https://docs.ultralytics.com/tasks/detection/#export for help.'
-                            f'\n\n{EXPORT_FORMATS_TABLE}')
+                            f'\n\n{export_formats()}')
 
         # Load external metadata YAML
         if isinstance(metadata, (str, Path)) and Path(metadata).exists():
diff --git a/ultralytics/nn/autoshape.py b/ultralytics/nn/autoshape.py
index 8b1b9205..3c983dc6 100644
--- a/ultralytics/nn/autoshape.py
+++ b/ultralytics/nn/autoshape.py
@@ -8,7 +8,6 @@ from pathlib import Path
 
 import cv2
 import numpy as np
-import pandas as pd
 import requests
 import torch
 import torch.nn as nn
@@ -204,12 +203,13 @@ class Detections:
 
     def pandas(self):
         # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
+        import pandas
         new = copy(self)  # return copy
         ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name'  # xyxy columns
         cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name'  # xywh columns
         for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
             a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)]  # update
-            setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
+            setattr(new, k, [pandas.DataFrame(x, columns=c) for x in a])
         return new
 
     def tolist(self):
diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index c453b696..65ab0b14 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -122,7 +122,7 @@ class BaseModel(nn.Module):
         bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
         return sum(isinstance(v, bn) for v in self.modules()) < thresh  # True if < 'thresh' BatchNorm layers in model
 
-    def info(self, verbose=False, imgsz=640):
+    def info(self, verbose=True, imgsz=640):
         """
         Prints model information
 
diff --git a/ultralytics/tracker/utils/matching.py b/ultralytics/tracker/utils/matching.py
index 02ff75cf..561f385f 100644
--- a/ultralytics/tracker/utils/matching.py
+++ b/ultralytics/tracker/utils/matching.py
@@ -36,18 +36,26 @@ def _indices_to_matches(cost_matrix, indices, thresh):
     return matches, unmatched_a, unmatched_b
 
 
-def linear_assignment(cost_matrix, thresh):
+def linear_assignment(cost_matrix, thresh, use_lap=True):
+    # Linear assignment implementations with scipy and lap.lapjv
     if cost_matrix.size == 0:
         return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
-    matches, unmatched_a, unmatched_b = [], [], []
 
-    # TODO: investigate scipy.optimize.linear_sum_assignment() for lap.lapjv()
-    cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
+    if use_lap:
+        _, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
+        matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0]
+        unmatched_a = np.where(x < 0)[0]
+        unmatched_b = np.where(y < 0)[0]
+    else:
+        # Scipy linear sum assignment is NOT working correctly, DO NOT USE
+        y, x = scipy.optimize.linear_sum_assignment(cost_matrix)  # row y, col x
+        matches = np.asarray([[i, x] for i, x in enumerate(x) if cost_matrix[i, x] <= thresh])
+        unmatched = np.ones(cost_matrix.shape)
+        for i, xi in matches:
+            unmatched[i, xi] = 0.0
+        unmatched_a = np.where(unmatched.all(1))[0]
+        unmatched_b = np.where(unmatched.all(0))[0]
 
-    matches.extend([ix, mx] for ix, mx in enumerate(x) if mx >= 0)
-    unmatched_a = np.where(x < 0)[0]
-    unmatched_b = np.where(y < 0)[0]
-    matches = np.asarray(matches)
     return matches, unmatched_a, unmatched_b
 
 
diff --git a/ultralytics/yolo/cfg/default.yaml b/ultralytics/yolo/cfg/default.yaml
index f1e6e97a..e727b0fd 100644
--- a/ultralytics/yolo/cfg/default.yaml
+++ b/ultralytics/yolo/cfg/default.yaml
@@ -17,7 +17,7 @@ cache: False  # True/ram, disk or False. Use cache for data loading
 device:  # device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu
 workers: 8  # number of worker threads for data loading (per RANK if DDP)
 project:  # project name
-name:  # experiment name
+name:  # experiment name, results saved to 'project/name' directory
 exist_ok: False  # whether to overwrite existing experiment
 pretrained: False  # whether to use a pretrained model
 optimizer: SGD  # optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
@@ -30,7 +30,6 @@ rect: False  # support rectangular training if mode='train', support rectangular
 cos_lr: False  # use cosine learning rate scheduler
 close_mosaic: 10  # disable mosaic augmentation for final 10 epochs
 resume: False  # resume training from last checkpoint
-min_memory: False  # minimize memory footprint loss function, choices=[False, True, <roll_out_thr>]
 # Segmentation
 overlap_mask: True  # masks should overlap during training (segment train only)
 mask_ratio: 4  # mask downsample ratio (segment train only)
diff --git a/ultralytics/yolo/data/utils.py b/ultralytics/yolo/data/utils.py
index 9b15ee6f..dff19f1e 100644
--- a/ultralytics/yolo/data/utils.py
+++ b/ultralytics/yolo/data/utils.py
@@ -23,8 +23,8 @@ from ultralytics.yolo.utils.downloads import download, safe_download, unzip_file
 from ultralytics.yolo.utils.ops import segments2boxes
 
 HELP_URL = 'See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
-IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm'  # include image suffixes
-VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv'  # include video suffixes
+IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm'  # image suffixes
+VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv', 'webm'  # video suffixes
 LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
 RANK = int(os.getenv('RANK', -1))
 PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true'  # global pin_memory for dataloaders
diff --git a/ultralytics/yolo/engine/exporter.py b/ultralytics/yolo/engine/exporter.py
index 7f690c6a..cb0fc44a 100644
--- a/ultralytics/yolo/engine/exporter.py
+++ b/ultralytics/yolo/engine/exporter.py
@@ -57,16 +57,13 @@ from collections import defaultdict
 from copy import deepcopy
 from pathlib import Path
 
-import numpy as np
-import pandas as pd
 import torch
 
 from ultralytics.nn.autobackend import check_class_names
 from ultralytics.nn.modules import C2f, Detect, Segment
 from ultralytics.nn.tasks import DetectionModel, SegmentationModel
 from ultralytics.yolo.cfg import get_cfg
-from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
-from ultralytics.yolo.data.utils import IMAGENET_MEAN, IMAGENET_STD, check_det_dataset
+from ultralytics.yolo.data.utils import IMAGENET_MEAN, IMAGENET_STD
 from ultralytics.yolo.utils import (DEFAULT_CFG, LINUX, LOGGER, MACOS, __version__, callbacks, colorstr,
                                     get_default_args, yaml_save)
 from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version
@@ -79,6 +76,7 @@ ARM64 = platform.machine() in ('arm64', 'aarch64')
 
 def export_formats():
     # YOLOv8 export formats
+    import pandas
     x = [
         ['PyTorch', '-', '.pt', True, True],
         ['TorchScript', 'torchscript', '.torchscript', True, True],
@@ -92,11 +90,7 @@ def export_formats():
         ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', True, False],
         ['TensorFlow.js', 'tfjs', '_web_model', True, False],
         ['PaddlePaddle', 'paddle', '_paddle_model', True, True], ]
-    return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
-
-
-EXPORT_FORMATS_LIST = list(export_formats()['Argument'][1:])
-EXPORT_FORMATS_TABLE = str(export_formats())
+    return pandas.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
 
 
 def gd_outputs(gd):
@@ -537,7 +531,7 @@ class Exporter:
         # Export to TF
         int8 = '-oiqt -qt per-tensor' if self.args.int8 else ''
         cmd = f'onnx2tf -i {f_onnx} -o {f} -nuo --non_verbose {int8}'
-        LOGGER.info(f'\n{prefix} running {cmd}')
+        LOGGER.info(f"\n{prefix} running '{cmd}'")
         subprocess.run(cmd, shell=True)
         yaml_save(f / 'metadata.yaml', self.metadata)  # add metadata.yaml
 
@@ -574,47 +568,47 @@ class Exporter:
         LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
         saved_model = Path(str(self.file).replace(self.file.suffix, '_saved_model'))
         if self.args.int8:
-            f = saved_model / (self.file.stem + '_integer_quant.tflite')  # fp32 in/out
+            f = saved_model / f'{self.file.stem}_integer_quant.tflite'  # fp32 in/out
         elif self.args.half:
-            f = saved_model / (self.file.stem + '_float16.tflite')
+            f = saved_model / f'{self.file.stem}_float16.tflite'
         else:
-            f = saved_model / (self.file.stem + '_float32.tflite')
-        return str(f), None  # noqa
+            f = saved_model / f'{self.file.stem}_float32.tflite'
+        return str(f), None
 
-        # OLD VERSION BELOW ---------------------------------------------------------------
-        batch_size, ch, *imgsz = list(self.im.shape)  # BCHW
-        f = str(self.file).replace(self.file.suffix, '-fp16.tflite')
-
-        converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
-        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
-        converter.target_spec.supported_types = [tf.float16]
-        converter.optimizations = [tf.lite.Optimize.DEFAULT]
-        if self.args.int8:
-
-            def representative_dataset_gen(dataset, n_images=100):
-                # Dataset generator for use with converter.representative_dataset, returns a generator of np arrays
-                for n, (path, img, im0s, vid_cap, string) in enumerate(dataset):
-                    im = np.transpose(img, [1, 2, 0])
-                    im = np.expand_dims(im, axis=0).astype(np.float32)
-                    im /= 255
-                    yield [im]
-                    if n >= n_images:
-                        break
-
-            dataset = LoadImages(check_det_dataset(self.args.data)['train'], imgsz=imgsz, auto=False)
-            converter.representative_dataset = lambda: representative_dataset_gen(dataset, n_images=100)
-            converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
-            converter.target_spec.supported_types = []
-            converter.inference_input_type = tf.uint8  # or tf.int8
-            converter.inference_output_type = tf.uint8  # or tf.int8
-            converter.experimental_new_quantizer = True
-            f = str(self.file).replace(self.file.suffix, '-int8.tflite')
-        if nms or agnostic_nms:
-            converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
-
-        tflite_model = converter.convert()
-        open(f, 'wb').write(tflite_model)
-        return f, None
+        # # OLD TFLITE EXPORT CODE BELOW -------------------------------------------------------------------------------
+        # batch_size, ch, *imgsz = list(self.im.shape)  # BCHW
+        # f = str(self.file).replace(self.file.suffix, '-fp16.tflite')
+        #
+        # converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+        # converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
+        # converter.target_spec.supported_types = [tf.float16]
+        # converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        # if self.args.int8:
+        #
+        #     def representative_dataset_gen(dataset, n_images=100):
+        #         # Dataset generator for use with converter.representative_dataset, returns a generator of np arrays
+        #         for n, (path, img, im0s, vid_cap, string) in enumerate(dataset):
+        #             im = np.transpose(img, [1, 2, 0])
+        #             im = np.expand_dims(im, axis=0).astype(np.float32)
+        #             im /= 255
+        #             yield [im]
+        #             if n >= n_images:
+        #                 break
+        #
+        #     dataset = LoadImages(check_det_dataset(self.args.data)['train'], imgsz=imgsz, auto=False)
+        #     converter.representative_dataset = lambda: representative_dataset_gen(dataset, n_images=100)
+        #     converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        #     converter.target_spec.supported_types = []
+        #     converter.inference_input_type = tf.uint8  # or tf.int8
+        #     converter.inference_output_type = tf.uint8  # or tf.int8
+        #     converter.experimental_new_quantizer = True
+        #     f = str(self.file).replace(self.file.suffix, '-int8.tflite')
+        # if nms or agnostic_nms:
+        #     converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
+        #
+        # tflite_model = converter.convert()
+        # open(f, 'wb').write(tflite_model)
+        # return f, None
 
     @try_export
     def _export_edgetpu(self, tflite_model='', prefix=colorstr('Edge TPU:')):
@@ -638,6 +632,7 @@ class Exporter:
         f = str(tflite_model).replace('.tflite', '_edgetpu.tflite')  # Edge TPU model
 
         cmd = f'edgetpu_compiler -s -d -k 10 --out_dir {Path(f).parent} {tflite_model}'
+        LOGGER.info(f"{prefix} running '{cmd}'")
         subprocess.run(cmd.split(), check=True)
         self._add_tflite_metadata(f)
         return f, None
diff --git a/ultralytics/yolo/engine/results.py b/ultralytics/yolo/engine/results.py
index ef4ec942..a8a2120b 100644
--- a/ultralytics/yolo/engine/results.py
+++ b/ultralytics/yolo/engine/results.py
@@ -48,7 +48,7 @@ class Results:
         self.probs = probs if probs is not None else None
         self.names = names
         self.path = path
-        self._keys = (k for k in ('boxes', 'masks', 'probs') if getattr(self, k) is not None)
+        self._keys = [k for k in ('boxes', 'masks', 'probs') if getattr(self, k) is not None]
 
     def pandas(self):
         pass
@@ -122,8 +122,7 @@ class Results:
         Returns:
             (None) or (PIL.Image): If `pil` is True, a PIL Image is returned. Otherwise, nothing is returned.
         """
-        img = deepcopy(self.orig_img)
-        annotator = Annotator(img, line_width, font_size, font, pil, example)
+        annotator = Annotator(deepcopy(self.orig_img), line_width, font_size, font, pil, example)
         boxes = self.boxes
         masks = self.masks
         logits = self.probs
@@ -136,7 +135,7 @@ class Results:
                 annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
 
         if masks is not None:
-            im = torch.as_tensor(img, dtype=torch.float16, device=masks.data.device).permute(2, 0, 1).flip(0)
+            im = torch.as_tensor(annotator.im, dtype=torch.float16, device=masks.data.device).permute(2, 0, 1).flip(0)
             im = F.resize(im.contiguous(), masks.data.shape[1:]) / 255
             annotator.masks(masks.data, colors=[colors(x, True) for x in boxes.cls], im_gpu=im)
 
@@ -146,7 +145,7 @@ class Results:
             text = f"{', '.join(f'{names[j] if names else j} {logits[j]:.2f}' for j in top5i)}, "
             annotator.text((32, 32), text, txt_color=(255, 255, 255))  # TODO: allow setting colors
 
-        return img
+        return np.asarray(annotator.im) if annotator.pil else annotator.im
 
 
 class Boxes:
diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py
index 27abb7f2..55159c38 100644
--- a/ultralytics/yolo/engine/trainer.py
+++ b/ultralytics/yolo/engine/trainer.py
@@ -197,12 +197,15 @@ class BaseTrainer:
         """
         Builds dataloaders and optimizer on correct rank process.
         """
-        # model
+        # Model
         self.run_callbacks('on_pretrain_routine_start')
         ckpt = self.setup_model()
         self.model = self.model.to(self.device)
         self.set_model_attributes()
+        # Check AMP
+        callbacks_backup = callbacks.default_callbacks.copy()  # backup callbacks as they are reset by check_amp()
         self.amp = check_amp(self.model)
+        callbacks.default_callbacks = callbacks_backup  # restore callbacks
         self.scaler = amp.GradScaler(enabled=self.amp)
         if world_size > 1:
             self.model = DDP(self.model, device_ids=[rank])
@@ -610,7 +613,7 @@ def check_amp(model):
         a = m(im, device=device, verbose=False)[0].boxes.boxes  # FP32 inference
         with torch.cuda.amp.autocast(True):
             b = m(im, device=device, verbose=False)[0].boxes.boxes  # AMP inference
-        return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.1)  # close to 10% absolute tolerance
+        return a.shape == b.shape and torch.allclose(a, b.float(), rtol=0.1)  # close to 10% absolute tolerance
 
     f = ROOT / 'assets/bus.jpg'  # image to check
     im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if ONLINE else np.ones((640, 640, 3))
diff --git a/ultralytics/yolo/utils/__init__.py b/ultralytics/yolo/utils/__init__.py
index d95837d5..5b4dfdda 100644
--- a/ultralytics/yolo/utils/__init__.py
+++ b/ultralytics/yolo/utils/__init__.py
@@ -17,7 +17,6 @@ from typing import Union
 
 import cv2
 import numpy as np
-import pandas as pd
 import torch
 import yaml
 
@@ -95,8 +94,6 @@ HELP_MSG = \
 # Settings
 torch.set_printoptions(linewidth=320, precision=5, profile='long')
 np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
-pd.options.display.max_columns = 10
-pd.options.display.width = 120
 cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
 os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS)  # NumExpr max threads
 os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # for deterministic training
diff --git a/ultralytics/yolo/utils/benchmarks.py b/ultralytics/yolo/utils/benchmarks.py
index 2d091c87..4a16114b 100644
--- a/ultralytics/yolo/utils/benchmarks.py
+++ b/ultralytics/yolo/utils/benchmarks.py
@@ -26,8 +26,6 @@ import platform
 import time
 from pathlib import Path
 
-import pandas as pd
-
 from ultralytics import YOLO
 from ultralytics.yolo.engine.exporter import export_formats
 from ultralytics.yolo.utils import LINUX, LOGGER, ROOT, SETTINGS
@@ -38,6 +36,9 @@ from ultralytics.yolo.utils.torch_utils import select_device
 
 
 def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', imgsz=160, half=False, device='cpu', hard_fail=False):
+    import pandas as pd
+    pd.options.display.max_columns = 10
+    pd.options.display.width = 120
     device = select_device(device, verbose=False)
     if isinstance(model, (str, Path)):
         model = YOLO(model)
diff --git a/ultralytics/yolo/utils/plotting.py b/ultralytics/yolo/utils/plotting.py
index 33c95d6a..b764ffcf 100644
--- a/ultralytics/yolo/utils/plotting.py
+++ b/ultralytics/yolo/utils/plotting.py
@@ -8,7 +8,6 @@ import cv2
 import matplotlib
 import matplotlib.pyplot as plt
 import numpy as np
-import pandas as pd
 import torch
 from PIL import Image, ImageDraw, ImageFont
 from PIL import __version__ as pil_version
@@ -160,6 +159,7 @@ class Annotator:
 
 @TryExcept()  # known issue https://github.com/ultralytics/yolov5/issues/5395
 def plot_labels(boxes, cls, names=(), save_dir=Path('')):
+    import pandas as pd
     import seaborn as sn
 
     # plot dataset labels
@@ -275,7 +275,7 @@ def plot_images(images,
         x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
         annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2)  # borders
         if paths:
-            annotator.text((x + 5, y + 5 + h), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220))  # filenames
+            annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220))  # filenames
         if len(cls) > 0:
             idx = batch_idx == i
 
@@ -330,6 +330,7 @@ def plot_images(images,
 
 def plot_results(file='path/to/results.csv', dir='', segment=False):
     # Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
+    import pandas as pd
     save_dir = Path(file).parent if file else Path(dir)
     if segment:
         fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
diff --git a/ultralytics/yolo/utils/tal.py b/ultralytics/yolo/utils/tal.py
index cf48a2f6..0b714144 100644
--- a/ultralytics/yolo/utils/tal.py
+++ b/ultralytics/yolo/utils/tal.py
@@ -10,7 +10,7 @@ from .metrics import bbox_iou
 TORCH_1_10 = check_version(torch.__version__, '1.10.0')
 
 
-def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9, roll_out=False):
+def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
     """select the positive anchor center in gt
 
     Args:
@@ -21,18 +21,10 @@ def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9, roll_out=False):
     """
     n_anchors = xy_centers.shape[0]
     bs, n_boxes, _ = gt_bboxes.shape
-    if roll_out:
-        bbox_deltas = torch.empty((bs, n_boxes, n_anchors), device=gt_bboxes.device)
-        for b in range(bs):
-            lt, rb = gt_bboxes[b].view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottom
-            bbox_deltas[b] = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]),
-                                       dim=2).view(n_boxes, n_anchors, -1).amin(2).gt_(eps)
-        return bbox_deltas
-    else:
-        lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottom
-        bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
-        # return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype)
-        return bbox_deltas.amin(3).gt_(eps)
+    lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottom
+    bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
+    # return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype)
+    return bbox_deltas.amin(3).gt_(eps)
 
 
 def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
@@ -63,7 +55,7 @@ def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
 
 class TaskAlignedAssigner(nn.Module):
 
-    def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9, roll_out_thr=0):
+    def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
         super().__init__()
         self.topk = topk
         self.num_classes = num_classes
@@ -71,7 +63,6 @@ class TaskAlignedAssigner(nn.Module):
         self.alpha = alpha
         self.beta = beta
         self.eps = eps
-        self.roll_out_thr = roll_out_thr
 
     @torch.no_grad()
     def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
@@ -93,7 +84,6 @@ class TaskAlignedAssigner(nn.Module):
         """
         self.bs = pd_scores.size(0)
         self.n_max_boxes = gt_bboxes.size(1)
-        self.roll_out = self.n_max_boxes > self.roll_out_thr if self.roll_out_thr else False
 
         if self.n_max_boxes == 0:
             device = gt_bboxes.device
@@ -119,40 +109,35 @@ class TaskAlignedAssigner(nn.Module):
         return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
 
     def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
-        # get anchor_align metric, (b, max_num_obj, h*w)
-        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes)
         # get in_gts mask, (b, max_num_obj, h*w)
-        mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes, roll_out=self.roll_out)
+        mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
+        # get anchor_align metric, (b, max_num_obj, h*w)
+        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
         # get topk_metric mask, (b, max_num_obj, h*w)
-        mask_topk = self.select_topk_candidates(align_metric * mask_in_gts,
-                                                topk_mask=mask_gt.repeat([1, 1, self.topk]).bool())
+        mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.repeat([1, 1, self.topk]).bool())
         # merge all mask to a final mask, (b, max_num_obj, h*w)
         mask_pos = mask_topk * mask_in_gts * mask_gt
 
         return mask_pos, align_metric, overlaps
 
-    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes):
-        if self.roll_out:
-            align_metric = torch.empty((self.bs, self.n_max_boxes, pd_scores.shape[1]), device=pd_scores.device)
-            overlaps = torch.empty((self.bs, self.n_max_boxes, pd_scores.shape[1]), device=pd_scores.device)
-            ind_0 = torch.empty(self.n_max_boxes, dtype=torch.long)
-            for b in range(self.bs):
-                ind_0[:], ind_2 = b, gt_labels[b].squeeze(-1).long()
-                # get the scores of each grid for each gt cls
-                bbox_scores = pd_scores[ind_0, :, ind_2]  # b, max_num_obj, h*w
-                overlaps[b] = bbox_iou(gt_bboxes[b].unsqueeze(1), pd_bboxes[b].unsqueeze(0), xywh=False,
-                                       CIoU=True).squeeze(2).clamp(0)
-                align_metric[b] = bbox_scores.pow(self.alpha) * overlaps[b].pow(self.beta)
-        else:
-            ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
-            ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes)  # b, max_num_obj
-            ind[1] = gt_labels.long().squeeze(-1)  # b, max_num_obj
-            # get the scores of each grid for each gt cls
-            bbox_scores = pd_scores[ind[0], :, ind[1]]  # b, max_num_obj, h*w
+    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
+        na = pd_bboxes.shape[-2]
+        mask_gt = mask_gt.bool()  # b, max_num_obj, h*w
+        overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
+        bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)
 
-            overlaps = bbox_iou(gt_bboxes.unsqueeze(2), pd_bboxes.unsqueeze(1), xywh=False,
-                                CIoU=True).squeeze(3).clamp(0)
-            align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
+        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
+        ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes)  # b, max_num_obj
+        ind[1] = gt_labels.long().squeeze(-1)  # b, max_num_obj
+        # get the scores of each grid for each gt cls
+        bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt]  # b, max_num_obj, h*w
+
+        # (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
+        pd_boxes = pd_bboxes.unsqueeze(1).repeat(1, self.n_max_boxes, 1, 1)[mask_gt]
+        gt_boxes = gt_bboxes.unsqueeze(2).repeat(1, 1, na, 1)[mask_gt]
+        overlaps[mask_gt] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp(0)
+
+        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
         return align_metric, overlaps
 
     def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
@@ -170,12 +155,10 @@ class TaskAlignedAssigner(nn.Module):
         # (b, max_num_obj, topk)
         topk_idxs[~topk_mask] = 0
         # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
-        if self.roll_out:
-            is_in_topk = torch.empty(metrics.shape, dtype=torch.long, device=metrics.device)
-            for b in range(len(topk_idxs)):
-                is_in_topk[b] = F.one_hot(topk_idxs[b], num_anchors).sum(-2)
-        else:
-            is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(-2)
+        is_in_topk = torch.zeros(metrics.shape, dtype=torch.long, device=metrics.device)
+        for it in range(self.topk):
+            is_in_topk += F.one_hot(topk_idxs[:, :, it], num_anchors)
+        # is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(-2)
         # filter invalid bboxes
         is_in_topk = torch.where(is_in_topk > 1, 0, is_in_topk)
         return is_in_topk.to(metrics.dtype)
diff --git a/ultralytics/yolo/v8/classify/val.py b/ultralytics/yolo/v8/classify/val.py
index 0a2afb68..9eed5807 100644
--- a/ultralytics/yolo/v8/classify/val.py
+++ b/ultralytics/yolo/v8/classify/val.py
@@ -33,6 +33,7 @@ class ClassificationValidator(BaseValidator):
 
     def finalize_metrics(self, *args, **kwargs):
         self.metrics.speed = self.speed
+        # self.metrics.confusion_matrix = self.confusion_matrix  # TODO: classification ConfusionMatrix
 
     def get_stats(self):
         self.metrics.process(self.targets, self.pred)
diff --git a/ultralytics/yolo/v8/detect/train.py b/ultralytics/yolo/v8/detect/train.py
index 98ab7aea..a3e8f21b 100644
--- a/ultralytics/yolo/v8/detect/train.py
+++ b/ultralytics/yolo/v8/detect/train.py
@@ -124,13 +124,8 @@ class Loss:
         self.device = device
 
         self.use_dfl = m.reg_max > 1
-        roll_out_thr = h.min_memory if h.min_memory > 1 else 64 if h.min_memory else 0  # 64 is default
 
-        self.assigner = TaskAlignedAssigner(topk=10,
-                                            num_classes=self.nc,
-                                            alpha=0.5,
-                                            beta=6.0,
-                                            roll_out_thr=roll_out_thr)
+        self.assigner = TaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
         self.bbox_loss = BboxLoss(m.reg_max - 1, use_dfl=self.use_dfl).to(device)
         self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
 
diff --git a/ultralytics/yolo/v8/detect/val.py b/ultralytics/yolo/v8/detect/val.py
index 116c6308..b4bc45e0 100644
--- a/ultralytics/yolo/v8/detect/val.py
+++ b/ultralytics/yolo/v8/detect/val.py
@@ -113,6 +113,7 @@ class DetectionValidator(BaseValidator):
 
     def finalize_metrics(self, *args, **kwargs):
         self.metrics.speed = self.speed
+        self.metrics.confusion_matrix = self.confusion_matrix
 
     def get_stats(self):
         stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)]  # to numpy
diff --git a/ultralytics/yolo/v8/segment/val.py b/ultralytics/yolo/v8/segment/val.py
index 26759f12..403f8800 100644
--- a/ultralytics/yolo/v8/segment/val.py
+++ b/ultralytics/yolo/v8/segment/val.py
@@ -121,6 +121,7 @@ class SegmentationValidator(DetectionValidator):
 
     def finalize_metrics(self, *args, **kwargs):
         self.metrics.speed = self.speed
+        self.metrics.confusion_matrix = self.confusion_matrix
 
     def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False):
         """