From 16ce193d6e4ae4718a2ce20fa9d716c5464fc730 Mon Sep 17 00:00:00 2001
From: Glenn Jocher <glenn.jocher@ultralytics.com>
Date: Sun, 10 Sep 2023 03:27:23 +0200
Subject: [PATCH] `ultralytics 8.0.174` Tuner plots and improvements (#4799)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 docs/reference/utils/__init__.md |   4 ++
 docs/reference/utils/plotting.md |   8 +++
 tests/test_cuda.py               |  21 +++---
 ultralytics/__init__.py          |   2 +-
 ultralytics/cfg/__init__.py      |   2 +-
 ultralytics/engine/model.py      |   2 +-
 ultralytics/engine/trainer.py    |  27 ++++----
 ultralytics/engine/tuner.py      | 111 ++++++++++++++++++-------------
 ultralytics/hub/session.py       |   3 +-
 ultralytics/utils/__init__.py    |  46 ++++++++++++-
 ultralytics/utils/plotting.py    | 100 +++++++++++++++++++++++++++-
 ultralytics/utils/torch_utils.py |  10 +--
 12 files changed, 248 insertions(+), 88 deletions(-)

diff --git a/docs/reference/utils/__init__.md b/docs/reference/utils/__init__.md
index 31c22ad1..7c32363f 100644
--- a/docs/reference/utils/__init__.md
+++ b/docs/reference/utils/__init__.md
@@ -133,6 +133,10 @@ keywords: Ultralytics, Utils, utilitarian functions, colorstr, yaml_save, set_lo
 ## ::: ultralytics.utils.colorstr
 <br><br>
 
+---
+## ::: ultralytics.utils.remove_colorstr
+<br><br>
+
 ---
 ## ::: ultralytics.utils.threaded
 <br><br>
diff --git a/docs/reference/utils/plotting.md b/docs/reference/utils/plotting.md
index a0c28b13..15414eb5 100644
--- a/docs/reference/utils/plotting.md
+++ b/docs/reference/utils/plotting.md
@@ -33,6 +33,14 @@ keywords: Ultralytics, plotting, utils, color annotation, label plotting, image
 ## ::: ultralytics.utils.plotting.plot_results
 <br><br>
 
+---
+## ::: ultralytics.utils.plotting.plt_color_scatter
+<br><br>
+
+---
+## ::: ultralytics.utils.plotting.plot_tune_results
+<br><br>
+
 ---
 ## ::: ultralytics.utils.plotting.output_to_target
 <br><br>
diff --git a/tests/test_cuda.py b/tests/test_cuda.py
index 44dc159c..7063271b 100644
--- a/tests/test_cuda.py
+++ b/tests/test_cuda.py
@@ -94,8 +94,8 @@ def test_model_ray_tune():
 
 @pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
 def test_model_tune():
-    YOLO('yolov8n-pose.pt').tune(data='coco8-pose.yaml', imgsz=32, epochs=1, iterations=2, device='cpu')
-    YOLO('yolov8n-cls.pt').tune(data='imagenet10', imgsz=32, epochs=1, iterations=2, device='cpu')
+    YOLO('yolov8n-pose.pt').tune(data='coco8-pose.yaml', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu')
+    YOLO('yolov8n-cls.pt').tune(data='imagenet10', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu')
 
 
 @pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
@@ -107,27 +107,22 @@ def test_pycocotools():
     # Download annotations after each dataset downloads first
     url = 'https://github.com/ultralytics/assets/releases/download/v0.0.0/'
 
-    validator = DetectionValidator(args={'model': 'yolov8n.pt', 'data': 'coco8.yaml', 'save_json': True, 'imgsz': 64})
+    args = {'model': 'yolov8n.pt', 'data': 'coco8.yaml', 'save_json': True, 'imgsz': 64}
+    validator = DetectionValidator(args=args)
     validator()
     validator.is_coco = True
     download(f'{url}instances_val2017.json', dir=DATASETS_DIR / 'coco8/annotations')
     _ = validator.eval_json(validator.stats)
 
-    validator = SegmentationValidator(args={
-        'model': 'yolov8n-seg.pt',
-        'data': 'coco8-seg.yaml',
-        'save_json': True,
-        'imgsz': 64})
+    args = {'model': 'yolov8n-seg.pt', 'data': 'coco8-seg.yaml', 'save_json': True, 'imgsz': 64}
+    validator = SegmentationValidator(args=args)
     validator()
     validator.is_coco = True
     download(f'{url}instances_val2017.json', dir=DATASETS_DIR / 'coco8-seg/annotations')
     _ = validator.eval_json(validator.stats)
 
-    validator = PoseValidator(args={
-        'model': 'yolov8n-pose.pt',
-        'data': 'coco8-pose.yaml',
-        'save_json': True,
-        'imgsz': 64})
+    args = {'model': 'yolov8n-pose.pt', 'data': 'coco8-pose.yaml', 'save_json': True, 'imgsz': 64}
+    validator = PoseValidator(args=args)
     validator()
     validator.is_coco = True
     download(f'{url}person_keypoints_val2017.json', dir=DATASETS_DIR / 'coco8-pose/annotations')
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index a0ae59fe..7a9f7f8c 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = '8.0.173'
+__version__ = '8.0.174'
 
 from ultralytics.models import RTDETR, SAM, YOLO
 from ultralytics.models.fastsam import FastSAM
diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py
index 7bc48f27..fadbb372 100644
--- a/ultralytics/cfg/__init__.py
+++ b/ultralytics/cfg/__init__.py
@@ -358,7 +358,7 @@ def entrypoint(debug=''):
         if '=' in a:
             try:
                 k, v = parse_key_value_pair(a)
-                if k == 'cfg':  # custom.yaml passed
+                if k == 'cfg' and v is not None:  # custom.yaml passed
                     LOGGER.info(f'Overriding {DEFAULT_CFG_PATH} with {v}')
                     overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != 'cfg'}
                 else:
diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py
index ac573236..0dbdd64a 100644
--- a/ultralytics/engine/model.py
+++ b/ultralytics/engine/model.py
@@ -361,7 +361,7 @@ class Model:
         else:
             from .tuner import Tuner
 
-            custom = {'plots': False, 'save': False}  # method defaults
+            custom = {}  # method defaults
             args = {**self.overrides, **custom, **kwargs, 'mode': 'train'}  # highest priority args on the right
             return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)
 
diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py
index 4ff4229a..9ccb3f36 100644
--- a/ultralytics/engine/trainer.py
+++ b/ultralytics/engine/trainer.py
@@ -423,7 +423,10 @@ class BaseTrainer:
         self.run_callbacks('teardown')
 
     def save_model(self):
-        """Save model checkpoints based on various conditions."""
+        """Save model training checkpoints with additional metadata."""
+        import pandas as pd  # scope for faster startup
+        metrics = {**self.metrics, **{'fitness': self.fitness}}
+        results = {k.strip(): v for k, v in pd.read_csv(self.save_dir / 'results.csv').to_dict(orient='list').items()}
         ckpt = {
             'epoch': self.epoch,
             'best_fitness': self.best_fitness,
@@ -432,22 +435,17 @@ class BaseTrainer:
             'updates': self.ema.updates,
             'optimizer': self.optimizer.state_dict(),
             'train_args': vars(self.args),  # save as dict
+            'train_metrics': metrics,
+            'train_results': results,
             'date': datetime.now().isoformat(),
             'version': __version__}
 
-        # Use dill (if exists) to serialize the lambda functions where pickle does not do this
-        try:
-            import dill as pickle
-        except ImportError:
-            import pickle
-
-        # Save last, best and delete
-        torch.save(ckpt, self.last, pickle_module=pickle)
+        # Save last and best
+        torch.save(ckpt, self.last)
         if self.best_fitness == self.fitness:
-            torch.save(ckpt, self.best, pickle_module=pickle)
-        if (self.epoch > 0) and (self.save_period > 0) and (self.epoch % self.save_period == 0):
-            torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt', pickle_module=pickle)
-        del ckpt
+            torch.save(ckpt, self.best)
+        if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
+            torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt')
 
     @staticmethod
     def get_dataset(data):
@@ -654,6 +652,9 @@ class BaseTrainer:
         g = [], [], []  # optimizer parameter groups
         bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
         if name == 'auto':
+            LOGGER.info(f"{colorstr('optimizer:')} 'optimizer=auto' found, "
+                        f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
+                        f"determining best 'optimizer', 'lr0' and 'momentum' automatically... ")
             nc = getattr(model, 'nc', 10)  # number of classes
             lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
             name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 10000 else ('AdamW', lr_fit, 0.9)
diff --git a/ultralytics/engine/tuner.py b/ultralytics/engine/tuner.py
index 0702690d..d60a56c2 100644
--- a/ultralytics/engine/tuner.py
+++ b/ultralytics/engine/tuner.py
@@ -13,18 +13,20 @@ Example:
     from ultralytics import YOLO
 
     model = YOLO('yolov8n.pt')
-    model.tune(data='coco8.yaml', imgsz=640, epochs=100, iterations=10)
+    model.tune(data='coco8.yaml', epochs=10, iterations=300, optimizer='AdamW', plots=False, save=False, val=False)
     ```
 """
 import random
+import shutil
+import subprocess
 import time
-from copy import deepcopy
 
 import numpy as np
+import torch
 
-from ultralytics import YOLO
 from ultralytics.cfg import get_cfg, get_save_dir
-from ultralytics.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, yaml_print, yaml_save
+from ultralytics.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, remove_colorstr, yaml_print, yaml_save
+from ultralytics.utils.plotting import plot_tune_results
 
 
 class Tuner:
@@ -37,7 +39,7 @@ class Tuner:
      Attributes:
          space (dict): Hyperparameter search space containing bounds and scaling factors for mutation.
          tune_dir (Path): Directory where evolution logs and results will be saved.
-         evolve_csv (Path): Path to the CSV file where evolution logs are saved.
+         tune_csv (Path): Path to the CSV file where evolution logs are saved.
 
      Methods:
          _mutate(hyp: dict) -> dict:
@@ -52,7 +54,7 @@ class Tuner:
          from ultralytics import YOLO
 
          model = YOLO('yolov8n.pt')
-         model.tune(data='coco8.yaml', imgsz=640, epochs=100, iterations=10, val=False, cache=True)
+         model.tune(data='coco8.yaml', epochs=10, iterations=300, optimizer='AdamW', plots=False, save=False, val=False)
          ```
      """
 
@@ -64,22 +66,23 @@ class Tuner:
             args (dict, optional): Configuration for hyperparameter evolution.
         """
         self.args = get_cfg(overrides=args)
-        self.space = {  # key: (min, max, gain(optionaL))
+        self.space = {  # key: (min, max, gain(optional))
             # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
             'lr0': (1e-5, 1e-1),
-            'lrf': (0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
-            'momentum': (0.6, 0.98, 0.3),  # SGD momentum/Adam beta1
+            'lrf': (0.0001, 0.1),  # final OneCycleLR learning rate (lr0 * lrf)
+            'momentum': (0.7, 0.98, 0.3),  # SGD momentum/Adam beta1
             'weight_decay': (0.0, 0.001),  # optimizer weight decay 5e-4
             'warmup_epochs': (0.0, 5.0),  # warmup epochs (fractions ok)
             'warmup_momentum': (0.0, 0.95),  # warmup initial momentum
-            'box': (0.02, 0.2),  # box loss gain
+            'box': (1.0, 20.0),  # box loss gain
             'cls': (0.2, 4.0),  # cls loss gain (scale with pixels)
+            'dfl': (0.4, 6.0),  # dfl loss gain
             'hsv_h': (0.0, 0.1),  # image HSV-Hue augmentation (fraction)
             'hsv_s': (0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
             'hsv_v': (0.0, 0.9),  # image HSV-Value augmentation (fraction)
             'degrees': (0.0, 45.0),  # image rotation (+/- deg)
             'translate': (0.0, 0.9),  # image translation (+/- fraction)
-            'scale': (0.0, 0.9),  # image scale (+/- gain)
+            'scale': (0.0, 0.95),  # image scale (+/- gain)
             'shear': (0.0, 10.0),  # image shear (+/- deg)
             'perspective': (0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
             'flipud': (0.0, 1.0),  # image flip up-down (probability)
@@ -87,11 +90,13 @@ class Tuner:
             'mosaic': (0.0, 1.0),  # image mixup (probability)
             'mixup': (0.0, 1.0),  # image mixup (probability)
             'copy_paste': (0.0, 1.0)}  # segment copy-paste (probability)
-        self.tune_dir = get_save_dir(self.args, name='_tune')
-        self.evolve_csv = self.tune_dir / 'evolve.csv'
+        self.tune_dir = get_save_dir(self.args, name='tune')
+        self.tune_csv = self.tune_dir / 'tune_results.csv'
         self.callbacks = _callbacks or callbacks.get_default_callbacks()
+        self.prefix = colorstr('Tuner: ')
         callbacks.add_integration_callbacks(self)
-        LOGGER.info(f"Initialized Tuner instance with 'tune_dir={self.tune_dir}'.")
+        LOGGER.info(f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"
+                    f'{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning')
 
     def _mutate(self, parent='single', n=5, mutation=0.8, sigma=0.2):
         """
@@ -106,9 +111,9 @@ class Tuner:
         Returns:
             (dict): A dictionary containing mutated hyperparameters.
         """
-        if self.evolve_csv.exists():  # if evolve.csv exists: select best hyps and mutate
+        if self.tune_csv.exists():  # if CSV file exists: select best hyps and mutate
             # Select parent(s)
-            x = np.loadtxt(self.evolve_csv, ndmin=2, delimiter=',', skiprows=1)
+            x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=',', skiprows=1)
             fitness = x[:, 0]  # first column
             n = min(n, len(x))  # number of previous results to consider
             x = x[np.argsort(-fitness)][:n]  # top n mutations
@@ -139,7 +144,7 @@ class Tuner:
 
         return hyp
 
-    def __call__(self, model=None, iterations=10, prefix=colorstr('Tuner:')):
+    def __call__(self, model=None, iterations=10, cleanup=True):
         """
         Executes the hyperparameter evolution process when the Tuner instance is called.
 
@@ -152,54 +157,68 @@ class Tuner:
         Args:
            model (Model): A pre-initialized YOLO model to be used for training.
            iterations (int): The number of generations to run the evolution for.
+           cleanup (bool): Whether to delete iteration weights to reduce storage space used during tuning.
 
         Note:
-           The method utilizes the `self.evolve_csv` Path object to read and log hyperparameters and fitness scores.
+           The method utilizes the `self.tune_csv` Path object to read and log hyperparameters and fitness scores.
            Ensure this path is set correctly in the Tuner instance.
         """
 
         t0 = time.time()
         best_save_dir, best_metrics = None, None
-        self.tune_dir.mkdir(parents=True, exist_ok=True)
+        (self.tune_dir / 'weights').mkdir(parents=True, exist_ok=True)
         for i in range(iterations):
             # Mutate hyperparameters
             mutated_hyp = self._mutate()
-            LOGGER.info(f'{prefix} Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}')
+            LOGGER.info(f'{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}')
 
+            metrics = {}
+            train_args = {**vars(self.args), **mutated_hyp}
+            save_dir = get_save_dir(get_cfg(train_args))
             try:
-                # Train YOLO model with mutated hyperparameters
-                train_args = {**vars(self.args), **mutated_hyp}
-                results = (deepcopy(model) or YOLO(self.args.model)).train(**train_args)
-                fitness = results.fitness
-            except Exception as e:
-                LOGGER.warning(f'WARNING ❌️ training failure for hyperparameter tuning iteration {i}\n{e}')
-                fitness = 0.0
+                # Train YOLO model with mutated hyperparameters (run in subprocess to avoid dataloader hang)
+                weights_dir = save_dir / 'weights'
+                cmd = ['yolo', 'train', *(f'{k}={v}' for k, v in train_args.items())]
+                assert subprocess.run(cmd, check=True).returncode == 0, 'training failed'
+                ckpt_file = weights_dir / ('best.pt' if (weights_dir / 'best.pt').exists() else 'last.pt')
+                metrics = torch.load(ckpt_file)['train_metrics']
 
-            # Save results and mutated_hyp to evolve_csv
+            except Exception as e:
+                LOGGER.warning(f'WARNING ❌️ training failure for hyperparameter tuning iteration {i + 1}\n{e}')
+
+            # Save results and mutated_hyp to CSV
+            fitness = metrics.get('fitness', 0.0)
             log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()]
-            headers = '' if self.evolve_csv.exists() else (','.join(['fitness_score'] + list(self.space.keys())) + '\n')
-            with open(self.evolve_csv, 'a') as f:
+            headers = '' if self.tune_csv.exists() else (','.join(['fitness'] + list(self.space.keys())) + '\n')
+            with open(self.tune_csv, 'a') as f:
                 f.write(headers + ','.join(map(str, log_row)) + '\n')
 
-            # Print tuning results
-            x = np.loadtxt(self.evolve_csv, ndmin=2, delimiter=',', skiprows=1)
+            # Get best results
+            x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=',', skiprows=1)
             fitness = x[:, 0]  # first column
             best_idx = fitness.argmax()
             best_is_current = best_idx == i
             if best_is_current:
-                best_save_dir = results.save_dir
-                best_metrics = {k: round(v, 5) for k, v in results.results_dict.items()}
-            header = (f'{prefix} {i + 1} iterations complete ✅ ({time.time() - t0:.2f}s)\n'
-                      f'{prefix} Results saved to {colorstr("bold", self.tune_dir)}\n'
-                      f'{prefix} Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n'
-                      f'{prefix} Best fitness metrics are {best_metrics}\n'
-                      f'{prefix} Best fitness model is {best_save_dir}\n'
-                      f'{prefix} Best fitness hyperparameters are printed below.\n')
+                best_save_dir = save_dir
+                best_metrics = {k: round(v, 5) for k, v in metrics.items()}
+                for ckpt in weights_dir.glob('*.pt'):
+                    shutil.copy2(ckpt, self.tune_dir / 'weights')
+            elif cleanup:
+                shutil.rmtree(ckpt_file.parent)  # remove iteration weights/ dir to reduce storage space
 
+            # Plot tune results
+            plot_tune_results(self.tune_csv)
+
+            # Save and print tune results
+            header = (f'{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n'
+                      f'{self.prefix}Results saved to {colorstr("bold", self.tune_dir)}\n'
+                      f'{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n'
+                      f'{self.prefix}Best fitness metrics are {best_metrics}\n'
+                      f'{self.prefix}Best fitness model is {best_save_dir}\n'
+                      f'{self.prefix}Best fitness hyperparameters are printed below.\n')
             LOGGER.info('\n' + header)
-
-            # Save turning results
-            data = {k: float(x[0, i + 1]) for i, k in enumerate(self.space.keys())}
-            header = header.replace(prefix, '#').replace('/', '').replace('', '') + '\n'
-            yaml_save(self.tune_dir / 'best.yaml', data=data, header=header)
-            yaml_print(self.tune_dir / 'best.yaml')
+            data = {k: float(x[best_idx, i + 1]) for i, k in enumerate(self.space.keys())}
+            yaml_save(self.tune_dir / 'best_hyperparameters.yaml',
+                      data=data,
+                      header=remove_colorstr(header.replace(self.prefix, '# ')) + '\n')
+            yaml_print(self.tune_dir / 'best_hyperparameters.yaml')
diff --git a/ultralytics/hub/session.py b/ultralytics/hub/session.py
index 595de297..396f9b37 100644
--- a/ultralytics/hub/session.py
+++ b/ultralytics/hub/session.py
@@ -117,8 +117,7 @@ class HUBTrainingSession:
 
             if data['status'] == 'new':  # new model to start training
                 self.train_args = {
-                    # TODO: deprecate 'batch_size' key for 'batch' in 3Q23
-                    'batch': data['batch' if ('batch' in data) else 'batch_size'],
+                    'batch': data['batch'],
                     'epochs': data['epochs'],
                     'imgsz': data['imgsz'],
                     'patience': data['patience'],
diff --git a/ultralytics/utils/__init__.py b/ultralytics/utils/__init__.py
index 872ce5fd..bdf981a1 100644
--- a/ultralytics/utils/__init__.py
+++ b/ultralytics/utils/__init__.py
@@ -635,7 +635,33 @@ SETTINGS_YAML = USER_CONFIG_DIR / 'settings.yaml'
 
 
 def colorstr(*input):
-    """Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e.  colorstr('blue', 'hello world')."""
+    """
+    Colors a string based on the provided color and style arguments. Utilizes ANSI escape codes.
+    See https://en.wikipedia.org/wiki/ANSI_escape_code for more details.
+
+    This function can be called in two ways:
+        - colorstr('color', 'style', 'your string')
+        - colorstr('your string')
+
+    In the second form, 'blue' and 'bold' will be applied by default.
+
+    Args:
+        *input (str): A sequence of strings where the first n-1 strings are color and style arguments,
+                      and the last string is the one to be colored.
+
+    Supported Colors and Styles:
+        Basic Colors: 'black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'
+        Bright Colors: 'bright_black', 'bright_red', 'bright_green', 'bright_yellow',
+                       'bright_blue', 'bright_magenta', 'bright_cyan', 'bright_white'
+        Misc: 'end', 'bold', 'underline'
+
+    Returns:
+        (str): The input string wrapped with ANSI escape codes for the specified color and style.
+
+    Examples:
+        >>> colorstr('blue', 'bold', 'hello world')
+        >>> '\033[34m\033[1mhello world\033[0m'
+    """
     *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
     colors = {
         'black': '\033[30m',  # basic colors
@@ -660,6 +686,24 @@ def colorstr(*input):
     return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
 
 
+def remove_colorstr(input_string):
+    """
+    Removes ANSI escape codes from a string, effectively un-coloring it.
+
+    Args:
+        input_string (str): The string to remove color and style from.
+
+    Returns:
+        (str): A new string with all ANSI escape codes removed.
+
+    Examples:
+        >>> remove_colorstr(colorstr('blue', 'bold', 'hello world'))
+        >>> 'hello world'
+    """
+    ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
+    return ansi_escape.sub('', input_string)
+
+
 class TryExcept(contextlib.ContextDecorator):
     """YOLOv8 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager."""
 
diff --git a/ultralytics/utils/plotting.py b/ultralytics/utils/plotting.py
index 6237f133..3d70d3c7 100644
--- a/ultralytics/utils/plotting.py
+++ b/ultralytics/utils/plotting.py
@@ -498,13 +498,23 @@ def plot_images(images,
 @plt_settings()
 def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False, classify=False, on_plot=None):
     """
-    Plot training results from results CSV file.
+    Plot training results from a results CSV file. The function supports various types of data including segmentation,
+    pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
+
+    Args:
+        file (str, optional): Path to the CSV file containing the training results. Defaults to 'path/to/results.csv'.
+        dir (str, optional): Directory where the CSV file is located if 'file' is not provided. Defaults to ''.
+        segment (bool, optional): Flag to indicate if the data is for segmentation. Defaults to False.
+        pose (bool, optional): Flag to indicate if the data is for pose estimation. Defaults to False.
+        classify (bool, optional): Flag to indicate if the data is for classification. Defaults to False.
+        on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument.
+            Defaults to None.
 
     Example:
         ```python
         from ultralytics.utils.plotting import plot_results
 
-        plot_results('path/to/results.csv')
+        plot_results('path/to/results.csv', segment=True)
         ```
     """
     import pandas as pd
@@ -548,6 +558,92 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False,
         on_plot(fname)
 
 
+def plt_color_scatter(v, f, bins=20, cmap='viridis', alpha=0.8, edgecolors='none'):
+    """
+    Plots a scatter plot with points colored based on a 2D histogram.
+
+    Args:
+        v (array-like): Values for the x-axis.
+        f (array-like): Values for the y-axis.
+        bins (int, optional): Number of bins for the histogram. Defaults to 20.
+        cmap (str, optional): Colormap for the scatter plot. Defaults to 'viridis'.
+        alpha (float, optional): Alpha for the scatter plot. Defaults to 0.8.
+        edgecolors (str, optional): Edge colors for the scatter plot. Defaults to 'none'.
+
+    Examples:
+        >>> v = np.random.rand(100)
+        >>> f = np.random.rand(100)
+        >>> plt_color_scatter(v, f)
+    """
+
+    # Calculate 2D histogram and corresponding colors
+    hist, xedges, yedges = np.histogram2d(v, f, bins=bins)
+    colors = [
+        hist[min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1),
+             min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1)] for i in range(len(v))]
+
+    # Scatter plot
+    plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)
+
+
+def plot_tune_results(csv_file='tune_results.csv'):
+    """
+    Plot the evolution results stored in an 'tune_results.csv' file. The function generates a scatter plot for each key
+    in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.
+
+    Args:
+        csv_file (str, optional): Path to the CSV file containing the tuning results. Defaults to 'tune_results.csv'.
+
+    Examples:
+        >>> plot_tune_results('path/to/tune_results.csv')
+    """
+
+    import pandas as pd
+    from scipy.ndimage import gaussian_filter1d
+
+    # Scatter plots for each hyperparameter
+    csv_file = Path(csv_file)
+    data = pd.read_csv(csv_file)
+    num_metrics_columns = 1
+    keys = [x.strip() for x in data.columns][num_metrics_columns:]
+    x = data.values
+    fitness = x[:, 0]  # fitness
+    j = np.argmax(fitness)  # max fitness index
+    n = math.ceil(len(keys) ** 0.5)  # columns and rows in plot
+    plt.figure(figsize=(10, 10), tight_layout=True)
+    for i, k in enumerate(keys):
+        v = x[:, i + num_metrics_columns]
+        mu = v[j]  # best single result
+        plt.subplot(n, n, i + 1)
+        plt_color_scatter(v, fitness, cmap='viridis', alpha=.8, edgecolors='none')
+        plt.plot(mu, fitness.max(), 'k+', markersize=15)
+        plt.title(f'{k} = {mu:.3g}', fontdict={'size': 9})  # limit to 40 characters
+        plt.tick_params(axis='both', labelsize=8)  # Set axis label size to 8
+        if i % n != 0:
+            plt.yticks([])
+
+    file = csv_file.with_name('tune_scatter_plots.png')  # filename
+    plt.savefig(file, dpi=200)
+    plt.close()
+    LOGGER.info(f'Saved {file}')
+
+    # Fitness vs iteration
+    x = range(1, len(fitness) + 1)
+    plt.figure(figsize=(10, 6), tight_layout=True)
+    plt.plot(x, fitness, marker='o', linestyle='none', label='fitness')
+    plt.plot(x, gaussian_filter1d(fitness, sigma=3), ':', label='smoothed', linewidth=2)  # smoothing line
+    plt.title('Fitness vs Iteration')
+    plt.xlabel('Iteration')
+    plt.ylabel('Fitness')
+    plt.grid(True)
+    plt.legend()
+
+    file = csv_file.with_name('tune_fitness.png')  # filename
+    plt.savefig(file, dpi=200)
+    plt.close()
+    LOGGER.info(f'Saved {file}')
+
+
 def output_to_target(output, max_det=300):
     """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
     targets = []
diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py
index def7442b..dcab101c 100644
--- a/ultralytics/utils/torch_utils.py
+++ b/ultralytics/utils/torch_utils.py
@@ -395,12 +395,6 @@ def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
             strip_optimizer(f)
         ```
     """
-    # Use dill (if exists) to serialize the lambda functions where pickle does not do this
-    try:
-        import dill as pickle
-    except ImportError:
-        import pickle
-
     x = torch.load(f, map_location=torch.device('cpu'))
     if 'model' not in x:
         LOGGER.info(f'Skipping {f}, not a valid Ultralytics model.')
@@ -419,8 +413,8 @@ def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
         p.requires_grad = False
     x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # strip non-default keys
     # x['model'].args = x['train_args']
-    torch.save(x, s or f, pickle_module=pickle)
-    mb = os.path.getsize(s or f) / 1E6  # filesize
+    torch.save(x, s or f)
+    mb = os.path.getsize(s or f) / 1E6  # file size
     LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")