diff --git a/docs/reference/utils/__init__.md b/docs/reference/utils/__init__.md
index 31c22ad1..7c32363f 100644
--- a/docs/reference/utils/__init__.md
+++ b/docs/reference/utils/__init__.md
@@ -133,6 +133,10 @@ keywords: Ultralytics, Utils, utilitarian functions, colorstr, yaml_save, set_lo
## ::: ultralytics.utils.colorstr
+---
+## ::: ultralytics.utils.remove_colorstr
+
+
---
## ::: ultralytics.utils.threaded
diff --git a/docs/reference/utils/plotting.md b/docs/reference/utils/plotting.md
index a0c28b13..15414eb5 100644
--- a/docs/reference/utils/plotting.md
+++ b/docs/reference/utils/plotting.md
@@ -33,6 +33,14 @@ keywords: Ultralytics, plotting, utils, color annotation, label plotting, image
## ::: ultralytics.utils.plotting.plot_results
+---
+## ::: ultralytics.utils.plotting.plt_color_scatter
+
+
+---
+## ::: ultralytics.utils.plotting.plot_tune_results
+
+
---
## ::: ultralytics.utils.plotting.output_to_target
diff --git a/tests/test_cuda.py b/tests/test_cuda.py
index 44dc159c..7063271b 100644
--- a/tests/test_cuda.py
+++ b/tests/test_cuda.py
@@ -94,8 +94,8 @@ def test_model_ray_tune():
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
def test_model_tune():
- YOLO('yolov8n-pose.pt').tune(data='coco8-pose.yaml', imgsz=32, epochs=1, iterations=2, device='cpu')
- YOLO('yolov8n-cls.pt').tune(data='imagenet10', imgsz=32, epochs=1, iterations=2, device='cpu')
+ YOLO('yolov8n-pose.pt').tune(data='coco8-pose.yaml', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu')
+ YOLO('yolov8n-cls.pt').tune(data='imagenet10', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu')
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
@@ -107,27 +107,22 @@ def test_pycocotools():
# Download annotations after each dataset downloads first
url = 'https://github.com/ultralytics/assets/releases/download/v0.0.0/'
- validator = DetectionValidator(args={'model': 'yolov8n.pt', 'data': 'coco8.yaml', 'save_json': True, 'imgsz': 64})
+ args = {'model': 'yolov8n.pt', 'data': 'coco8.yaml', 'save_json': True, 'imgsz': 64}
+ validator = DetectionValidator(args=args)
validator()
validator.is_coco = True
download(f'{url}instances_val2017.json', dir=DATASETS_DIR / 'coco8/annotations')
_ = validator.eval_json(validator.stats)
- validator = SegmentationValidator(args={
- 'model': 'yolov8n-seg.pt',
- 'data': 'coco8-seg.yaml',
- 'save_json': True,
- 'imgsz': 64})
+ args = {'model': 'yolov8n-seg.pt', 'data': 'coco8-seg.yaml', 'save_json': True, 'imgsz': 64}
+ validator = SegmentationValidator(args=args)
validator()
validator.is_coco = True
download(f'{url}instances_val2017.json', dir=DATASETS_DIR / 'coco8-seg/annotations')
_ = validator.eval_json(validator.stats)
- validator = PoseValidator(args={
- 'model': 'yolov8n-pose.pt',
- 'data': 'coco8-pose.yaml',
- 'save_json': True,
- 'imgsz': 64})
+ args = {'model': 'yolov8n-pose.pt', 'data': 'coco8-pose.yaml', 'save_json': True, 'imgsz': 64}
+ validator = PoseValidator(args=args)
validator()
validator.is_coco = True
download(f'{url}person_keypoints_val2017.json', dir=DATASETS_DIR / 'coco8-pose/annotations')
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index a0ae59fe..7a9f7f8c 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-__version__ = '8.0.173'
+__version__ = '8.0.174'
from ultralytics.models import RTDETR, SAM, YOLO
from ultralytics.models.fastsam import FastSAM
diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py
index 7bc48f27..fadbb372 100644
--- a/ultralytics/cfg/__init__.py
+++ b/ultralytics/cfg/__init__.py
@@ -358,7 +358,7 @@ def entrypoint(debug=''):
if '=' in a:
try:
k, v = parse_key_value_pair(a)
- if k == 'cfg': # custom.yaml passed
+ if k == 'cfg' and v is not None: # custom.yaml passed
LOGGER.info(f'Overriding {DEFAULT_CFG_PATH} with {v}')
overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != 'cfg'}
else:
diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py
index ac573236..0dbdd64a 100644
--- a/ultralytics/engine/model.py
+++ b/ultralytics/engine/model.py
@@ -361,7 +361,7 @@ class Model:
else:
from .tuner import Tuner
- custom = {'plots': False, 'save': False} # method defaults
+ custom = {} # method defaults
args = {**self.overrides, **custom, **kwargs, 'mode': 'train'} # highest priority args on the right
return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)
diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py
index 4ff4229a..9ccb3f36 100644
--- a/ultralytics/engine/trainer.py
+++ b/ultralytics/engine/trainer.py
@@ -423,7 +423,10 @@ class BaseTrainer:
self.run_callbacks('teardown')
def save_model(self):
- """Save model checkpoints based on various conditions."""
+ """Save model training checkpoints with additional metadata."""
+ import pandas as pd # scope for faster startup
+ metrics = {**self.metrics, **{'fitness': self.fitness}}
+ results = {k.strip(): v for k, v in pd.read_csv(self.save_dir / 'results.csv').to_dict(orient='list').items()}
ckpt = {
'epoch': self.epoch,
'best_fitness': self.best_fitness,
@@ -432,22 +435,17 @@ class BaseTrainer:
'updates': self.ema.updates,
'optimizer': self.optimizer.state_dict(),
'train_args': vars(self.args), # save as dict
+ 'train_metrics': metrics,
+ 'train_results': results,
'date': datetime.now().isoformat(),
'version': __version__}
- # Use dill (if exists) to serialize the lambda functions where pickle does not do this
- try:
- import dill as pickle
- except ImportError:
- import pickle
-
- # Save last, best and delete
- torch.save(ckpt, self.last, pickle_module=pickle)
+ # Save last and best
+ torch.save(ckpt, self.last)
if self.best_fitness == self.fitness:
- torch.save(ckpt, self.best, pickle_module=pickle)
- if (self.epoch > 0) and (self.save_period > 0) and (self.epoch % self.save_period == 0):
- torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt', pickle_module=pickle)
- del ckpt
+ torch.save(ckpt, self.best)
+ if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
+ torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt')
@staticmethod
def get_dataset(data):
@@ -654,6 +652,9 @@ class BaseTrainer:
g = [], [], [] # optimizer parameter groups
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
if name == 'auto':
+ LOGGER.info(f"{colorstr('optimizer:')} 'optimizer=auto' found, "
+ f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
+ f"determining best 'optimizer', 'lr0' and 'momentum' automatically... ")
nc = getattr(model, 'nc', 10) # number of classes
lr_fit = round(0.002 * 5 / (4 + nc), 6) # lr0 fit equation to 6 decimal places
name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 10000 else ('AdamW', lr_fit, 0.9)
diff --git a/ultralytics/engine/tuner.py b/ultralytics/engine/tuner.py
index 0702690d..d60a56c2 100644
--- a/ultralytics/engine/tuner.py
+++ b/ultralytics/engine/tuner.py
@@ -13,18 +13,20 @@ Example:
from ultralytics import YOLO
model = YOLO('yolov8n.pt')
- model.tune(data='coco8.yaml', imgsz=640, epochs=100, iterations=10)
+ model.tune(data='coco8.yaml', epochs=10, iterations=300, optimizer='AdamW', plots=False, save=False, val=False)
```
"""
import random
+import shutil
+import subprocess
import time
-from copy import deepcopy
import numpy as np
+import torch
-from ultralytics import YOLO
from ultralytics.cfg import get_cfg, get_save_dir
-from ultralytics.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, yaml_print, yaml_save
+from ultralytics.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, remove_colorstr, yaml_print, yaml_save
+from ultralytics.utils.plotting import plot_tune_results
class Tuner:
@@ -37,7 +39,7 @@ class Tuner:
Attributes:
space (dict): Hyperparameter search space containing bounds and scaling factors for mutation.
tune_dir (Path): Directory where evolution logs and results will be saved.
- evolve_csv (Path): Path to the CSV file where evolution logs are saved.
+ tune_csv (Path): Path to the CSV file where evolution logs are saved.
Methods:
_mutate(hyp: dict) -> dict:
@@ -52,7 +54,7 @@ class Tuner:
from ultralytics import YOLO
model = YOLO('yolov8n.pt')
- model.tune(data='coco8.yaml', imgsz=640, epochs=100, iterations=10, val=False, cache=True)
+ model.tune(data='coco8.yaml', epochs=10, iterations=300, optimizer='AdamW', plots=False, save=False, val=False)
```
"""
@@ -64,22 +66,23 @@ class Tuner:
args (dict, optional): Configuration for hyperparameter evolution.
"""
self.args = get_cfg(overrides=args)
- self.space = { # key: (min, max, gain(optionaL))
+ self.space = { # key: (min, max, gain(optional))
# 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
'lr0': (1e-5, 1e-1),
- 'lrf': (0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
- 'momentum': (0.6, 0.98, 0.3), # SGD momentum/Adam beta1
+ 'lrf': (0.0001, 0.1), # final OneCycleLR learning rate (lr0 * lrf)
+ 'momentum': (0.7, 0.98, 0.3), # SGD momentum/Adam beta1
'weight_decay': (0.0, 0.001), # optimizer weight decay 5e-4
'warmup_epochs': (0.0, 5.0), # warmup epochs (fractions ok)
'warmup_momentum': (0.0, 0.95), # warmup initial momentum
- 'box': (0.02, 0.2), # box loss gain
+ 'box': (1.0, 20.0), # box loss gain
'cls': (0.2, 4.0), # cls loss gain (scale with pixels)
+ 'dfl': (0.4, 6.0), # dfl loss gain
'hsv_h': (0.0, 0.1), # image HSV-Hue augmentation (fraction)
'hsv_s': (0.0, 0.9), # image HSV-Saturation augmentation (fraction)
'hsv_v': (0.0, 0.9), # image HSV-Value augmentation (fraction)
'degrees': (0.0, 45.0), # image rotation (+/- deg)
'translate': (0.0, 0.9), # image translation (+/- fraction)
- 'scale': (0.0, 0.9), # image scale (+/- gain)
+ 'scale': (0.0, 0.95), # image scale (+/- gain)
'shear': (0.0, 10.0), # image shear (+/- deg)
'perspective': (0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
'flipud': (0.0, 1.0), # image flip up-down (probability)
@@ -87,11 +90,13 @@ class Tuner:
'mosaic': (0.0, 1.0), # image mixup (probability)
'mixup': (0.0, 1.0), # image mixup (probability)
'copy_paste': (0.0, 1.0)} # segment copy-paste (probability)
- self.tune_dir = get_save_dir(self.args, name='_tune')
- self.evolve_csv = self.tune_dir / 'evolve.csv'
+ self.tune_dir = get_save_dir(self.args, name='tune')
+ self.tune_csv = self.tune_dir / 'tune_results.csv'
self.callbacks = _callbacks or callbacks.get_default_callbacks()
+ self.prefix = colorstr('Tuner: ')
callbacks.add_integration_callbacks(self)
- LOGGER.info(f"Initialized Tuner instance with 'tune_dir={self.tune_dir}'.")
+ LOGGER.info(f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"
+ f'{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning')
def _mutate(self, parent='single', n=5, mutation=0.8, sigma=0.2):
"""
@@ -106,9 +111,9 @@ class Tuner:
Returns:
(dict): A dictionary containing mutated hyperparameters.
"""
- if self.evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate
+ if self.tune_csv.exists(): # if CSV file exists: select best hyps and mutate
# Select parent(s)
- x = np.loadtxt(self.evolve_csv, ndmin=2, delimiter=',', skiprows=1)
+ x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=',', skiprows=1)
fitness = x[:, 0] # first column
n = min(n, len(x)) # number of previous results to consider
x = x[np.argsort(-fitness)][:n] # top n mutations
@@ -139,7 +144,7 @@ class Tuner:
return hyp
- def __call__(self, model=None, iterations=10, prefix=colorstr('Tuner:')):
+ def __call__(self, model=None, iterations=10, cleanup=True):
"""
Executes the hyperparameter evolution process when the Tuner instance is called.
@@ -152,54 +157,68 @@ class Tuner:
Args:
model (Model): A pre-initialized YOLO model to be used for training.
iterations (int): The number of generations to run the evolution for.
+ cleanup (bool): Whether to delete iteration weights to reduce storage space used during tuning.
Note:
- The method utilizes the `self.evolve_csv` Path object to read and log hyperparameters and fitness scores.
+ The method utilizes the `self.tune_csv` Path object to read and log hyperparameters and fitness scores.
Ensure this path is set correctly in the Tuner instance.
"""
t0 = time.time()
best_save_dir, best_metrics = None, None
- self.tune_dir.mkdir(parents=True, exist_ok=True)
+ (self.tune_dir / 'weights').mkdir(parents=True, exist_ok=True)
for i in range(iterations):
# Mutate hyperparameters
mutated_hyp = self._mutate()
- LOGGER.info(f'{prefix} Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}')
+ LOGGER.info(f'{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}')
+ metrics = {}
+ train_args = {**vars(self.args), **mutated_hyp}
+ save_dir = get_save_dir(get_cfg(train_args))
try:
- # Train YOLO model with mutated hyperparameters
- train_args = {**vars(self.args), **mutated_hyp}
- results = (deepcopy(model) or YOLO(self.args.model)).train(**train_args)
- fitness = results.fitness
- except Exception as e:
- LOGGER.warning(f'WARNING ❌️ training failure for hyperparameter tuning iteration {i}\n{e}')
- fitness = 0.0
+ # Train YOLO model with mutated hyperparameters (run in subprocess to avoid dataloader hang)
+ weights_dir = save_dir / 'weights'
+ cmd = ['yolo', 'train', *(f'{k}={v}' for k, v in train_args.items())]
+ assert subprocess.run(cmd, check=True).returncode == 0, 'training failed'
+ ckpt_file = weights_dir / ('best.pt' if (weights_dir / 'best.pt').exists() else 'last.pt')
+ metrics = torch.load(ckpt_file)['train_metrics']
- # Save results and mutated_hyp to evolve_csv
+ except Exception as e:
+ LOGGER.warning(f'WARNING ❌️ training failure for hyperparameter tuning iteration {i + 1}\n{e}')
+
+ # Save results and mutated_hyp to CSV
+ fitness = metrics.get('fitness', 0.0)
log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()]
- headers = '' if self.evolve_csv.exists() else (','.join(['fitness_score'] + list(self.space.keys())) + '\n')
- with open(self.evolve_csv, 'a') as f:
+ headers = '' if self.tune_csv.exists() else (','.join(['fitness'] + list(self.space.keys())) + '\n')
+ with open(self.tune_csv, 'a') as f:
f.write(headers + ','.join(map(str, log_row)) + '\n')
- # Print tuning results
- x = np.loadtxt(self.evolve_csv, ndmin=2, delimiter=',', skiprows=1)
+ # Get best results
+ x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=',', skiprows=1)
fitness = x[:, 0] # first column
best_idx = fitness.argmax()
best_is_current = best_idx == i
if best_is_current:
- best_save_dir = results.save_dir
- best_metrics = {k: round(v, 5) for k, v in results.results_dict.items()}
- header = (f'{prefix} {i + 1} iterations complete ✅ ({time.time() - t0:.2f}s)\n'
- f'{prefix} Results saved to {colorstr("bold", self.tune_dir)}\n'
- f'{prefix} Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n'
- f'{prefix} Best fitness metrics are {best_metrics}\n'
- f'{prefix} Best fitness model is {best_save_dir}\n'
- f'{prefix} Best fitness hyperparameters are printed below.\n')
+ best_save_dir = save_dir
+ best_metrics = {k: round(v, 5) for k, v in metrics.items()}
+ for ckpt in weights_dir.glob('*.pt'):
+ shutil.copy2(ckpt, self.tune_dir / 'weights')
+ elif cleanup:
+ shutil.rmtree(ckpt_file.parent) # remove iteration weights/ dir to reduce storage space
+ # Plot tune results
+ plot_tune_results(self.tune_csv)
+
+ # Save and print tune results
+ header = (f'{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n'
+ f'{self.prefix}Results saved to {colorstr("bold", self.tune_dir)}\n'
+ f'{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n'
+ f'{self.prefix}Best fitness metrics are {best_metrics}\n'
+ f'{self.prefix}Best fitness model is {best_save_dir}\n'
+ f'{self.prefix}Best fitness hyperparameters are printed below.\n')
LOGGER.info('\n' + header)
-
- # Save turning results
- data = {k: float(x[0, i + 1]) for i, k in enumerate(self.space.keys())}
- header = header.replace(prefix, '#').replace('[1m/', '').replace('[0m', '') + '\n'
- yaml_save(self.tune_dir / 'best.yaml', data=data, header=header)
- yaml_print(self.tune_dir / 'best.yaml')
+ data = {k: float(x[best_idx, i + 1]) for i, k in enumerate(self.space.keys())}
+ yaml_save(self.tune_dir / 'best_hyperparameters.yaml',
+ data=data,
+ header=remove_colorstr(header.replace(self.prefix, '# ')) + '\n')
+ yaml_print(self.tune_dir / 'best_hyperparameters.yaml')
diff --git a/ultralytics/hub/session.py b/ultralytics/hub/session.py
index 595de297..396f9b37 100644
--- a/ultralytics/hub/session.py
+++ b/ultralytics/hub/session.py
@@ -117,8 +117,7 @@ class HUBTrainingSession:
if data['status'] == 'new': # new model to start training
self.train_args = {
- # TODO: deprecate 'batch_size' key for 'batch' in 3Q23
- 'batch': data['batch' if ('batch' in data) else 'batch_size'],
+ 'batch': data['batch'],
'epochs': data['epochs'],
'imgsz': data['imgsz'],
'patience': data['patience'],
diff --git a/ultralytics/utils/__init__.py b/ultralytics/utils/__init__.py
index 872ce5fd..bdf981a1 100644
--- a/ultralytics/utils/__init__.py
+++ b/ultralytics/utils/__init__.py
@@ -635,7 +635,33 @@ SETTINGS_YAML = USER_CONFIG_DIR / 'settings.yaml'
def colorstr(*input):
- """Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')."""
+ """
+ Colors a string based on the provided color and style arguments. Utilizes ANSI escape codes.
+ See https://en.wikipedia.org/wiki/ANSI_escape_code for more details.
+
+ This function can be called in two ways:
+ - colorstr('color', 'style', 'your string')
+ - colorstr('your string')
+
+ In the second form, 'blue' and 'bold' will be applied by default.
+
+ Args:
+ *input (str): A sequence of strings where the first n-1 strings are color and style arguments,
+ and the last string is the one to be colored.
+
+ Supported Colors and Styles:
+ Basic Colors: 'black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'
+ Bright Colors: 'bright_black', 'bright_red', 'bright_green', 'bright_yellow',
+ 'bright_blue', 'bright_magenta', 'bright_cyan', 'bright_white'
+ Misc: 'end', 'bold', 'underline'
+
+ Returns:
+ (str): The input string wrapped with ANSI escape codes for the specified color and style.
+
+ Examples:
+ >>> colorstr('blue', 'bold', 'hello world')
+ >>> '\033[34m\033[1mhello world\033[0m'
+ """
*args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
colors = {
'black': '\033[30m', # basic colors
@@ -660,6 +686,24 @@ def colorstr(*input):
return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
+def remove_colorstr(input_string):
+ """
+ Removes ANSI escape codes from a string, effectively un-coloring it.
+
+ Args:
+ input_string (str): The string to remove color and style from.
+
+ Returns:
+ (str): A new string with all ANSI escape codes removed.
+
+ Examples:
+ >>> remove_colorstr(colorstr('blue', 'bold', 'hello world'))
+ >>> 'hello world'
+ """
+ ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
+ return ansi_escape.sub('', input_string)
+
+
class TryExcept(contextlib.ContextDecorator):
"""YOLOv8 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager."""
diff --git a/ultralytics/utils/plotting.py b/ultralytics/utils/plotting.py
index 6237f133..3d70d3c7 100644
--- a/ultralytics/utils/plotting.py
+++ b/ultralytics/utils/plotting.py
@@ -498,13 +498,23 @@ def plot_images(images,
@plt_settings()
def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False, classify=False, on_plot=None):
"""
- Plot training results from results CSV file.
+ Plot training results from a results CSV file. The function supports various types of data including segmentation,
+ pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
+
+ Args:
+ file (str, optional): Path to the CSV file containing the training results. Defaults to 'path/to/results.csv'.
+ dir (str, optional): Directory where the CSV file is located if 'file' is not provided. Defaults to ''.
+ segment (bool, optional): Flag to indicate if the data is for segmentation. Defaults to False.
+ pose (bool, optional): Flag to indicate if the data is for pose estimation. Defaults to False.
+ classify (bool, optional): Flag to indicate if the data is for classification. Defaults to False.
+ on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument.
+ Defaults to None.
Example:
```python
from ultralytics.utils.plotting import plot_results
- plot_results('path/to/results.csv')
+ plot_results('path/to/results.csv', segment=True)
```
"""
import pandas as pd
@@ -548,6 +558,92 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False,
on_plot(fname)
+def plt_color_scatter(v, f, bins=20, cmap='viridis', alpha=0.8, edgecolors='none'):
+ """
+ Plots a scatter plot with points colored based on a 2D histogram.
+
+ Args:
+ v (array-like): Values for the x-axis.
+ f (array-like): Values for the y-axis.
+ bins (int, optional): Number of bins for the histogram. Defaults to 20.
+ cmap (str, optional): Colormap for the scatter plot. Defaults to 'viridis'.
+ alpha (float, optional): Alpha for the scatter plot. Defaults to 0.8.
+ edgecolors (str, optional): Edge colors for the scatter plot. Defaults to 'none'.
+
+ Examples:
+ >>> v = np.random.rand(100)
+ >>> f = np.random.rand(100)
+ >>> plt_color_scatter(v, f)
+ """
+
+ # Calculate 2D histogram and corresponding colors
+ hist, xedges, yedges = np.histogram2d(v, f, bins=bins)
+ colors = [
+ hist[min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1),
+ min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1)] for i in range(len(v))]
+
+ # Scatter plot
+ plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)
+
+
+def plot_tune_results(csv_file='tune_results.csv'):
+ """
+ Plot the evolution results stored in an 'tune_results.csv' file. The function generates a scatter plot for each key
+ in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.
+
+ Args:
+ csv_file (str, optional): Path to the CSV file containing the tuning results. Defaults to 'tune_results.csv'.
+
+ Examples:
+ >>> plot_tune_results('path/to/tune_results.csv')
+ """
+
+ import pandas as pd
+ from scipy.ndimage import gaussian_filter1d
+
+ # Scatter plots for each hyperparameter
+ csv_file = Path(csv_file)
+ data = pd.read_csv(csv_file)
+ num_metrics_columns = 1
+ keys = [x.strip() for x in data.columns][num_metrics_columns:]
+ x = data.values
+ fitness = x[:, 0] # fitness
+ j = np.argmax(fitness) # max fitness index
+ n = math.ceil(len(keys) ** 0.5) # columns and rows in plot
+ plt.figure(figsize=(10, 10), tight_layout=True)
+ for i, k in enumerate(keys):
+ v = x[:, i + num_metrics_columns]
+ mu = v[j] # best single result
+ plt.subplot(n, n, i + 1)
+ plt_color_scatter(v, fitness, cmap='viridis', alpha=.8, edgecolors='none')
+ plt.plot(mu, fitness.max(), 'k+', markersize=15)
+ plt.title(f'{k} = {mu:.3g}', fontdict={'size': 9}) # limit to 40 characters
+ plt.tick_params(axis='both', labelsize=8) # Set axis label size to 8
+ if i % n != 0:
+ plt.yticks([])
+
+ file = csv_file.with_name('tune_scatter_plots.png') # filename
+ plt.savefig(file, dpi=200)
+ plt.close()
+ LOGGER.info(f'Saved {file}')
+
+ # Fitness vs iteration
+ x = range(1, len(fitness) + 1)
+ plt.figure(figsize=(10, 6), tight_layout=True)
+ plt.plot(x, fitness, marker='o', linestyle='none', label='fitness')
+ plt.plot(x, gaussian_filter1d(fitness, sigma=3), ':', label='smoothed', linewidth=2) # smoothing line
+ plt.title('Fitness vs Iteration')
+ plt.xlabel('Iteration')
+ plt.ylabel('Fitness')
+ plt.grid(True)
+ plt.legend()
+
+ file = csv_file.with_name('tune_fitness.png') # filename
+ plt.savefig(file, dpi=200)
+ plt.close()
+ LOGGER.info(f'Saved {file}')
+
+
def output_to_target(output, max_det=300):
"""Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
targets = []
diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py
index def7442b..dcab101c 100644
--- a/ultralytics/utils/torch_utils.py
+++ b/ultralytics/utils/torch_utils.py
@@ -395,12 +395,6 @@ def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
strip_optimizer(f)
```
"""
- # Use dill (if exists) to serialize the lambda functions where pickle does not do this
- try:
- import dill as pickle
- except ImportError:
- import pickle
-
x = torch.load(f, map_location=torch.device('cpu'))
if 'model' not in x:
LOGGER.info(f'Skipping {f}, not a valid Ultralytics model.')
@@ -419,8 +413,8 @@ def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
p.requires_grad = False
x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
# x['model'].args = x['train_args']
- torch.save(x, s or f, pickle_module=pickle)
- mb = os.path.getsize(s or f) / 1E6 # filesize
+ torch.save(x, s or f)
+ mb = os.path.getsize(s or f) / 1E6 # file size
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")