diff --git a/docs/reference/utils/__init__.md b/docs/reference/utils/__init__.md
index 31c22ad1..7c32363f 100644
--- a/docs/reference/utils/__init__.md
+++ b/docs/reference/utils/__init__.md
@@ -133,6 +133,10 @@ keywords: Ultralytics, Utils, utilitarian functions, colorstr, yaml_save, set_lo
 ## ::: ultralytics.utils.colorstr
 
+---
+## ::: ultralytics.utils.remove_colorstr
+
+
 ---
 ## ::: ultralytics.utils.threaded
 
diff --git a/docs/reference/utils/plotting.md b/docs/reference/utils/plotting.md
index a0c28b13..15414eb5 100644
--- a/docs/reference/utils/plotting.md
+++ b/docs/reference/utils/plotting.md
@@ -33,6 +33,14 @@ keywords: Ultralytics, plotting, utils, color annotation, label plotting, image
 ## ::: ultralytics.utils.plotting.plot_results
 
+---
+## ::: ultralytics.utils.plotting.plt_color_scatter
+
+
+---
+## ::: ultralytics.utils.plotting.plot_tune_results
+
+
 ---
 ## ::: ultralytics.utils.plotting.output_to_target
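The two reference pages above only register the new utilities with mkdocstrings; the functions themselves are added later in this diff. A minimal sketch, assuming the 8.0.174 package is installed, that checks the documented identifiers actually resolve:

```python
# Verify that the objects referenced by the new '## :::' entries exist (assumes ultralytics 8.0.174 installed).
import importlib

for name in ('ultralytics.utils.remove_colorstr',
             'ultralytics.utils.plotting.plt_color_scatter',
             'ultralytics.utils.plotting.plot_tune_results'):
    module, _, attr = name.rpartition('.')
    assert hasattr(importlib.import_module(module), attr), f'missing {name}'
print('all documented objects found')
```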

diff --git a/tests/test_cuda.py b/tests/test_cuda.py
index 44dc159c..7063271b 100644
--- a/tests/test_cuda.py
+++ b/tests/test_cuda.py
@@ -94,8 +94,8 @@ def test_model_ray_tune():
 
 @pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
 def test_model_tune():
-    YOLO('yolov8n-pose.pt').tune(data='coco8-pose.yaml', imgsz=32, epochs=1, iterations=2, device='cpu')
-    YOLO('yolov8n-cls.pt').tune(data='imagenet10', imgsz=32, epochs=1, iterations=2, device='cpu')
+    YOLO('yolov8n-pose.pt').tune(data='coco8-pose.yaml', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu')
+    YOLO('yolov8n-cls.pt').tune(data='imagenet10', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu')
 
 
 @pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
@@ -107,27 +107,22 @@ def test_pycocotools():
     # Download annotations after each dataset downloads first
     url = 'https://github.com/ultralytics/assets/releases/download/v0.0.0/'
 
-    validator = DetectionValidator(args={'model': 'yolov8n.pt', 'data': 'coco8.yaml', 'save_json': True, 'imgsz': 64})
+    args = {'model': 'yolov8n.pt', 'data': 'coco8.yaml', 'save_json': True, 'imgsz': 64}
+    validator = DetectionValidator(args=args)
     validator()
     validator.is_coco = True
     download(f'{url}instances_val2017.json', dir=DATASETS_DIR / 'coco8/annotations')
     _ = validator.eval_json(validator.stats)
 
-    validator = SegmentationValidator(args={
-        'model': 'yolov8n-seg.pt',
-        'data': 'coco8-seg.yaml',
-        'save_json': True,
-        'imgsz': 64})
+    args = {'model': 'yolov8n-seg.pt', 'data': 'coco8-seg.yaml', 'save_json': True, 'imgsz': 64}
+    validator = SegmentationValidator(args=args)
     validator()
     validator.is_coco = True
     download(f'{url}instances_val2017.json', dir=DATASETS_DIR / 'coco8-seg/annotations')
     _ = validator.eval_json(validator.stats)
 
-    validator = PoseValidator(args={
-        'model': 'yolov8n-pose.pt',
-        'data': 'coco8-pose.yaml',
-        'save_json': True,
-        'imgsz': 64})
+    args = {'model': 'yolov8n-pose.pt', 'data': 'coco8-pose.yaml', 'save_json': True, 'imgsz': 64}
+    validator = PoseValidator(args=args)
     validator()
     validator.is_coco = True
     download(f'{url}person_keypoints_val2017.json', dir=DATASETS_DIR / 'coco8-pose/annotations')
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index a0ae59fe..7a9f7f8c 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = '8.0.173'
+__version__ = '8.0.174'
 
 from ultralytics.models import RTDETR, SAM, YOLO
 from ultralytics.models.fastsam import FastSAM
diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py
index 7bc48f27..fadbb372 100644
--- a/ultralytics/cfg/__init__.py
+++ b/ultralytics/cfg/__init__.py
@@ -358,7 +358,7 @@ def entrypoint(debug=''):
         if '=' in a:
             try:
                 k, v = parse_key_value_pair(a)
-                if k == 'cfg':  # custom.yaml passed
+                if k == 'cfg' and v is not None:  # custom.yaml passed
                     LOGGER.info(f'Overriding {DEFAULT_CFG_PATH} with {v}')
                     overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != 'cfg'}
                 else:
diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py
index ac573236..0dbdd64a 100644
--- a/ultralytics/engine/model.py
+++ b/ultralytics/engine/model.py
@@ -361,7 +361,7 @@ class Model:
         else:
             from .tuner import Tuner
 
-            custom = {'plots': False, 'save': False}  # method defaults
+            custom = {}  # method defaults
             args = {**self.overrides, **custom, **kwargs, 'mode': 'train'}  # highest priority args on the right
             return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)
 
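With `custom = {}`, `Model.tune()` no longer forces `plots=False, save=False`; whatever the caller passes in `kwargs` wins, which is why the tests above now pass `plots=False` explicitly. A minimal dict-only sketch of the precedence rule used in the merge (values are illustrative):

```python
# Later dicts win in {**overrides, **custom, **kwargs, 'mode': 'train'}.
overrides = {'imgsz': 640, 'plots': True}   # e.g. self.overrides from the loaded model
custom = {}                                 # method defaults (previously {'plots': False, 'save': False})
kwargs = {'plots': False, 'epochs': 10}     # user-supplied call arguments

args = {**overrides, **custom, **kwargs, 'mode': 'train'}
print(args)  # {'imgsz': 640, 'plots': False, 'epochs': 10, 'mode': 'train'}
```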
diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py
index 4ff4229a..9ccb3f36 100644
--- a/ultralytics/engine/trainer.py
+++ b/ultralytics/engine/trainer.py
@@ -423,7 +423,10 @@ class BaseTrainer:
         self.run_callbacks('teardown')
 
     def save_model(self):
-        """Save model checkpoints based on various conditions."""
+        """Save model training checkpoints with additional metadata."""
+        import pandas as pd  # scope for faster startup
+        metrics = {**self.metrics, **{'fitness': self.fitness}}
+        results = {k.strip(): v for k, v in pd.read_csv(self.save_dir / 'results.csv').to_dict(orient='list').items()}
         ckpt = {
             'epoch': self.epoch,
             'best_fitness': self.best_fitness,
@@ -432,22 +435,17 @@
             'updates': self.ema.updates,
             'optimizer': self.optimizer.state_dict(),
             'train_args': vars(self.args),  # save as dict
+            'train_metrics': metrics,
+            'train_results': results,
             'date': datetime.now().isoformat(),
             'version': __version__}
 
-        # Use dill (if exists) to serialize the lambda functions where pickle does not do this
-        try:
-            import dill as pickle
-        except ImportError:
-            import pickle
-
-        # Save last, best and delete
-        torch.save(ckpt, self.last, pickle_module=pickle)
+        # Save last and best
+        torch.save(ckpt, self.last)
         if self.best_fitness == self.fitness:
-            torch.save(ckpt, self.best, pickle_module=pickle)
-        if (self.epoch > 0) and (self.save_period > 0) and (self.epoch % self.save_period == 0):
-            torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt', pickle_module=pickle)
-        del ckpt
+            torch.save(ckpt, self.best)
+        if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
+            torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt')
 
     @staticmethod
     def get_dataset(data):
@@ -654,6 +652,9 @@
         g = [], [], []  # optimizer parameter groups
         bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
         if name == 'auto':
+            LOGGER.info(f"{colorstr('optimizer:')} 'optimizer=auto' found, "
+                        f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
+                        f"determining best 'optimizer', 'lr0' and 'momentum' automatically... ")
             nc = getattr(model, 'nc', 10)  # number of classes
             lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
             name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 10000 else ('AdamW', lr_fit, 0.9)
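Because `save_model()` now embeds `train_metrics` and `train_results` in every checkpoint, results can be read back without parsing `results.csv`; this is what the subprocess-based Tuner below relies on. A hedged sketch of inspecting that metadata (the checkpoint path is hypothetical):

```python
import torch

# Inspect the new metadata written by BaseTrainer.save_model(); the path below is illustrative.
ckpt = torch.load('runs/detect/train/weights/best.pt', map_location='cpu')
print(ckpt['train_metrics'].get('fitness'))  # fitness value logged at save time
print(list(ckpt['train_results'])[:5])       # first few column names mirrored from results.csv
print(ckpt['version'], ckpt['date'])         # existing metadata fields kept from before
```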
diff --git a/ultralytics/engine/tuner.py b/ultralytics/engine/tuner.py
index 0702690d..d60a56c2 100644
--- a/ultralytics/engine/tuner.py
+++ b/ultralytics/engine/tuner.py
@@ -13,18 +13,20 @@ Example:
     from ultralytics import YOLO
 
     model = YOLO('yolov8n.pt')
-    model.tune(data='coco8.yaml', imgsz=640, epochs=100, iterations=10)
+    model.tune(data='coco8.yaml', epochs=10, iterations=300, optimizer='AdamW', plots=False, save=False, val=False)
     ```
 """
 import random
+import shutil
+import subprocess
 import time
-from copy import deepcopy
 
 import numpy as np
+import torch
 
-from ultralytics import YOLO
 from ultralytics.cfg import get_cfg, get_save_dir
-from ultralytics.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, yaml_print, yaml_save
+from ultralytics.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, remove_colorstr, yaml_print, yaml_save
+from ultralytics.utils.plotting import plot_tune_results
 
 
 class Tuner:
@@ -37,7 +39,7 @@ class Tuner:
     Attributes:
         space (dict): Hyperparameter search space containing bounds and scaling factors for mutation.
         tune_dir (Path): Directory where evolution logs and results will be saved.
-        evolve_csv (Path): Path to the CSV file where evolution logs are saved.
+        tune_csv (Path): Path to the CSV file where evolution logs are saved.
 
     Methods:
         _mutate(hyp: dict) -> dict:
@@ -52,7 +54,7 @@
         from ultralytics import YOLO
 
         model = YOLO('yolov8n.pt')
-        model.tune(data='coco8.yaml', imgsz=640, epochs=100, iterations=10, val=False, cache=True)
+        model.tune(data='coco8.yaml', epochs=10, iterations=300, optimizer='AdamW', plots=False, save=False, val=False)
         ```
     """
 
@@ -64,22 +66,23 @@
             args (dict, optional): Configuration for hyperparameter evolution.
         """
         self.args = get_cfg(overrides=args)
-        self.space = {  # key: (min, max, gain(optionaL))
+        self.space = {  # key: (min, max, gain(optional))
             # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
             'lr0': (1e-5, 1e-1),
-            'lrf': (0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
-            'momentum': (0.6, 0.98, 0.3),  # SGD momentum/Adam beta1
+            'lrf': (0.0001, 0.1),  # final OneCycleLR learning rate (lr0 * lrf)
+            'momentum': (0.7, 0.98, 0.3),  # SGD momentum/Adam beta1
             'weight_decay': (0.0, 0.001),  # optimizer weight decay 5e-4
             'warmup_epochs': (0.0, 5.0),  # warmup epochs (fractions ok)
             'warmup_momentum': (0.0, 0.95),  # warmup initial momentum
-            'box': (0.02, 0.2),  # box loss gain
+            'box': (1.0, 20.0),  # box loss gain
             'cls': (0.2, 4.0),  # cls loss gain (scale with pixels)
+            'dfl': (0.4, 6.0),  # dfl loss gain
             'hsv_h': (0.0, 0.1),  # image HSV-Hue augmentation (fraction)
             'hsv_s': (0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
             'hsv_v': (0.0, 0.9),  # image HSV-Value augmentation (fraction)
             'degrees': (0.0, 45.0),  # image rotation (+/- deg)
             'translate': (0.0, 0.9),  # image translation (+/- fraction)
-            'scale': (0.0, 0.9),  # image scale (+/- gain)
+            'scale': (0.0, 0.95),  # image scale (+/- gain)
             'shear': (0.0, 10.0),  # image shear (+/- deg)
             'perspective': (0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
             'flipud': (0.0, 1.0),  # image flip up-down (probability)
@@ -87,11 +90,13 @@
             'mosaic': (0.0, 1.0),  # image mixup (probability)
             'mixup': (0.0, 1.0),  # image mixup (probability)
             'copy_paste': (0.0, 1.0)}  # segment copy-paste (probability)
-        self.tune_dir = get_save_dir(self.args, name='_tune')
-        self.evolve_csv = self.tune_dir / 'evolve.csv'
+        self.tune_dir = get_save_dir(self.args, name='tune')
+        self.tune_csv = self.tune_dir / 'tune_results.csv'
         self.callbacks = _callbacks or callbacks.get_default_callbacks()
+        self.prefix = colorstr('Tuner: ')
         callbacks.add_integration_callbacks(self)
-        LOGGER.info(f"Initialized Tuner instance with 'tune_dir={self.tune_dir}'.")
+        LOGGER.info(f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"
+                    f'{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning')
 
     def _mutate(self, parent='single', n=5, mutation=0.8, sigma=0.2):
         """
@@ -106,9 +111,9 @@
         Returns:
             (dict): A dictionary containing mutated hyperparameters.
         """
-        if self.evolve_csv.exists():  # if evolve.csv exists: select best hyps and mutate
+        if self.tune_csv.exists():  # if CSV file exists: select best hyps and mutate
             # Select parent(s)
-            x = np.loadtxt(self.evolve_csv, ndmin=2, delimiter=',', skiprows=1)
+            x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=',', skiprows=1)
             fitness = x[:, 0]  # first column
             n = min(n, len(x))  # number of previous results to consider
             x = x[np.argsort(-fitness)][:n]  # top n mutations
@@ -139,7 +144,7 @@
 
         return hyp
 
-    def __call__(self, model=None, iterations=10, prefix=colorstr('Tuner:')):
+    def __call__(self, model=None, iterations=10, cleanup=True):
         """
         Executes the hyperparameter evolution process when the Tuner instance is called.
 
@@ -152,54 +157,68 @@
         Args:
             model (Model): A pre-initialized YOLO model to be used for training.
             iterations (int): The number of generations to run the evolution for.
+            cleanup (bool): Whether to delete iteration weights to reduce storage space used during tuning.
 
         Note:
-           The method utilizes the `self.evolve_csv` Path object to read and log hyperparameters and fitness scores.
+           The method utilizes the `self.tune_csv` Path object to read and log hyperparameters and fitness scores.
            Ensure this path is set correctly in the Tuner instance.
         """
 
         t0 = time.time()
         best_save_dir, best_metrics = None, None
-        self.tune_dir.mkdir(parents=True, exist_ok=True)
+        (self.tune_dir / 'weights').mkdir(parents=True, exist_ok=True)
         for i in range(iterations):
             # Mutate hyperparameters
             mutated_hyp = self._mutate()
-            LOGGER.info(f'{prefix} Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}')
+            LOGGER.info(f'{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}')
 
+            metrics = {}
+            train_args = {**vars(self.args), **mutated_hyp}
+            save_dir = get_save_dir(get_cfg(train_args))
             try:
-                # Train YOLO model with mutated hyperparameters
-                train_args = {**vars(self.args), **mutated_hyp}
-                results = (deepcopy(model) or YOLO(self.args.model)).train(**train_args)
-                fitness = results.fitness
-            except Exception as e:
-                LOGGER.warning(f'WARNING ❌️ training failure for hyperparameter tuning iteration {i}\n{e}')
-                fitness = 0.0
+                # Train YOLO model with mutated hyperparameters (run in subprocess to avoid dataloader hang)
+                weights_dir = save_dir / 'weights'
+                cmd = ['yolo', 'train', *(f'{k}={v}' for k, v in train_args.items())]
+                assert subprocess.run(cmd, check=True).returncode == 0, 'training failed'
+                ckpt_file = weights_dir / ('best.pt' if (weights_dir / 'best.pt').exists() else 'last.pt')
+                metrics = torch.load(ckpt_file)['train_metrics']
 
-            # Save results and mutated_hyp to evolve_csv
+            except Exception as e:
+                LOGGER.warning(f'WARNING ❌️ training failure for hyperparameter tuning iteration {i + 1}\n{e}')
+
+            # Save results and mutated_hyp to CSV
+            fitness = metrics.get('fitness', 0.0)
             log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()]
-            headers = '' if self.evolve_csv.exists() else (','.join(['fitness_score'] + list(self.space.keys())) + '\n')
-            with open(self.evolve_csv, 'a') as f:
+            headers = '' if self.tune_csv.exists() else (','.join(['fitness'] + list(self.space.keys())) + '\n')
+            with open(self.tune_csv, 'a') as f:
                 f.write(headers + ','.join(map(str, log_row)) + '\n')
 
-            # Print tuning results
-            x = np.loadtxt(self.evolve_csv, ndmin=2, delimiter=',', skiprows=1)
+            # Get best results
+            x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=',', skiprows=1)
             fitness = x[:, 0]  # first column
             best_idx = fitness.argmax()
             best_is_current = best_idx == i
             if best_is_current:
-                best_save_dir = results.save_dir
-                best_metrics = {k: round(v, 5) for k, v in results.results_dict.items()}
-            header = (f'{prefix} {i + 1} iterations complete ✅ ({time.time() - t0:.2f}s)\n'
-                      f'{prefix} Results saved to {colorstr("bold", self.tune_dir)}\n'
-                      f'{prefix} Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n'
-                      f'{prefix} Best fitness metrics are {best_metrics}\n'
-                      f'{prefix} Best fitness model is {best_save_dir}\n'
-                      f'{prefix} Best fitness hyperparameters are printed below.\n')
+                best_save_dir = save_dir
+                best_metrics = {k: round(v, 5) for k, v in metrics.items()}
+                for ckpt in weights_dir.glob('*.pt'):
+                    shutil.copy2(ckpt, self.tune_dir / 'weights')
+            elif cleanup:
+                shutil.rmtree(ckpt_file.parent)  # remove iteration weights/ dir to reduce storage space
+            # Plot tune results
+            plot_tune_results(self.tune_csv)
+
+            # Save and print tune results
+            header = (f'{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n'
+                      f'{self.prefix}Results saved to {colorstr("bold", self.tune_dir)}\n'
+                      f'{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n'
+                      f'{self.prefix}Best fitness metrics are {best_metrics}\n'
+                      f'{self.prefix}Best fitness model is {best_save_dir}\n'
+                      f'{self.prefix}Best fitness hyperparameters are printed below.\n')
             LOGGER.info('\n' + header)
-
-            # Save turning results
-            data = {k: float(x[0, i + 1]) for i, k in enumerate(self.space.keys())}
-            header = header.replace(prefix, '#').replace('/', '').replace('', '') + '\n'
-            yaml_save(self.tune_dir / 'best.yaml', data=data, header=header)
-            yaml_print(self.tune_dir / 'best.yaml')
+            data = {k: float(x[best_idx, i + 1]) for i, k in enumerate(self.space.keys())}
+            yaml_save(self.tune_dir / 'best_hyperparameters.yaml',
+                      data=data,
+                      header=remove_colorstr(header.replace(self.prefix, '# ')) + '\n')
+            yaml_print(self.tune_dir / 'best_hyperparameters.yaml')
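The parent-selection step in `Tuner._mutate()` reads `tune_results.csv`, whose first column is fitness and whose remaining columns are the logged hyperparameters. A standalone sketch of that selection on synthetic data, assuming the same CSV layout:

```python
import numpy as np

# Standalone sketch of the top-n parent selection in Tuner._mutate() (synthetic data, same column layout).
rng = np.random.default_rng(0)
x = np.hstack([rng.random((12, 1)), rng.random((12, 4))])  # 12 logged iterations, 4 hyperparameters

fitness = x[:, 0]                   # first column
n = min(5, len(x))                  # number of previous results to consider
top = x[np.argsort(-fitness)][:n]   # top n rows, best fitness first
print(top[0, 0], top[0, 1:])        # best fitness and its hyperparameter values
```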
diff --git a/ultralytics/hub/session.py b/ultralytics/hub/session.py
index 595de297..396f9b37 100644
--- a/ultralytics/hub/session.py
+++ b/ultralytics/hub/session.py
@@ -117,8 +117,7 @@ class HUBTrainingSession:
         if data['status'] == 'new':  # new model to start training
             self.train_args = {
-                # TODO: deprecate 'batch_size' key for 'batch' in 3Q23
-                'batch': data['batch' if ('batch' in data) else 'batch_size'],
+                'batch': data['batch'],
                 'epochs': data['epochs'],
                 'imgsz': data['imgsz'],
                 'patience': data['patience'],
diff --git a/ultralytics/utils/__init__.py b/ultralytics/utils/__init__.py
index 872ce5fd..bdf981a1 100644
--- a/ultralytics/utils/__init__.py
+++ b/ultralytics/utils/__init__.py
@@ -635,7 +635,33 @@ SETTINGS_YAML = USER_CONFIG_DIR / 'settings.yaml'
 
 
 def colorstr(*input):
-    """Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')."""
+    """
+    Colors a string based on the provided color and style arguments. Utilizes ANSI escape codes.
+    See https://en.wikipedia.org/wiki/ANSI_escape_code for more details.
+
+    This function can be called in two ways:
+        - colorstr('color', 'style', 'your string')
+        - colorstr('your string')
+
+    In the second form, 'blue' and 'bold' will be applied by default.
+
+    Args:
+        *input (str): A sequence of strings where the first n-1 strings are color and style arguments,
+                      and the last string is the one to be colored.
+
+    Supported Colors and Styles:
+        Basic Colors: 'black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'
+        Bright Colors: 'bright_black', 'bright_red', 'bright_green', 'bright_yellow',
+                       'bright_blue', 'bright_magenta', 'bright_cyan', 'bright_white'
+        Misc: 'end', 'bold', 'underline'
+
+    Returns:
+        (str): The input string wrapped with ANSI escape codes for the specified color and style.
+
+    Examples:
+        >>> colorstr('blue', 'bold', 'hello world')
+        >>> '\033[34m\033[1mhello world\033[0m'
+    """
     *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
     colors = {
         'black': '\033[30m',  # basic colors
@@ -660,6 +686,24 @@
     return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
 
 
+def remove_colorstr(input_string):
+    """
+    Removes ANSI escape codes from a string, effectively un-coloring it.
+
+    Args:
+        input_string (str): The string to remove color and style from.
+
+    Returns:
+        (str): A new string with all ANSI escape codes removed.
+
+    Examples:
+        >>> remove_colorstr(colorstr('blue', 'bold', 'hello world'))
+        >>> 'hello world'
+    """
+    ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
+    return ansi_escape.sub('', input_string)
+
+
 class TryExcept(contextlib.ContextDecorator):
     """YOLOv8 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager."""
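A short round-trip of the two helpers, mirroring the docstring examples above (assumes this version of the package is installed):

```python
from ultralytics.utils import colorstr, remove_colorstr

colored = colorstr('blue', 'bold', 'hello world')
print(repr(colored))                  # '\x1b[34m\x1b[1mhello world\x1b[0m'
print(remove_colorstr(colored))       # 'hello world'
print(remove_colorstr('plain text'))  # strings without ANSI codes pass through unchanged
```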
diff --git a/ultralytics/utils/plotting.py b/ultralytics/utils/plotting.py
index 6237f133..3d70d3c7 100644
--- a/ultralytics/utils/plotting.py
+++ b/ultralytics/utils/plotting.py
@@ -498,13 +498,23 @@ def plot_images(images,
 @plt_settings()
 def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False, classify=False, on_plot=None):
     """
-    Plot training results from results CSV file.
+    Plot training results from a results CSV file. The function supports various types of data including segmentation,
+    pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
+
+    Args:
+        file (str, optional): Path to the CSV file containing the training results. Defaults to 'path/to/results.csv'.
+        dir (str, optional): Directory where the CSV file is located if 'file' is not provided. Defaults to ''.
+        segment (bool, optional): Flag to indicate if the data is for segmentation. Defaults to False.
+        pose (bool, optional): Flag to indicate if the data is for pose estimation. Defaults to False.
+        classify (bool, optional): Flag to indicate if the data is for classification. Defaults to False.
+        on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument.
+            Defaults to None.
 
     Example:
         ```python
         from ultralytics.utils.plotting import plot_results
 
-        plot_results('path/to/results.csv')
+        plot_results('path/to/results.csv', segment=True)
         ```
     """
     import pandas as pd
@@ -548,6 +558,92 @@ def plot_results(file='path/to/results.csv', dir='', segment=False, pose=False,
             on_plot(fname)
 
 
+def plt_color_scatter(v, f, bins=20, cmap='viridis', alpha=0.8, edgecolors='none'):
+    """
+    Plots a scatter plot with points colored based on a 2D histogram.
+
+    Args:
+        v (array-like): Values for the x-axis.
+        f (array-like): Values for the y-axis.
+        bins (int, optional): Number of bins for the histogram. Defaults to 20.
+        cmap (str, optional): Colormap for the scatter plot. Defaults to 'viridis'.
+        alpha (float, optional): Alpha for the scatter plot. Defaults to 0.8.
+        edgecolors (str, optional): Edge colors for the scatter plot. Defaults to 'none'.
+
+    Examples:
+        >>> v = np.random.rand(100)
+        >>> f = np.random.rand(100)
+        >>> plt_color_scatter(v, f)
+    """
+
+    # Calculate 2D histogram and corresponding colors
+    hist, xedges, yedges = np.histogram2d(v, f, bins=bins)
+    colors = [
+        hist[min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1),
+             min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1)] for i in range(len(v))]
+
+    # Scatter plot
+    plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)
+
+
+def plot_tune_results(csv_file='tune_results.csv'):
+    """
+    Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key
+    in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.
+
+    Args:
+        csv_file (str, optional): Path to the CSV file containing the tuning results. Defaults to 'tune_results.csv'.
+
+    Examples:
+        >>> plot_tune_results('path/to/tune_results.csv')
+    """
+
+    import pandas as pd
+    from scipy.ndimage import gaussian_filter1d
+
+    # Scatter plots for each hyperparameter
+    csv_file = Path(csv_file)
+    data = pd.read_csv(csv_file)
+    num_metrics_columns = 1
+    keys = [x.strip() for x in data.columns][num_metrics_columns:]
+    x = data.values
+    fitness = x[:, 0]  # fitness
+    j = np.argmax(fitness)  # max fitness index
+    n = math.ceil(len(keys) ** 0.5)  # columns and rows in plot
+    plt.figure(figsize=(10, 10), tight_layout=True)
+    for i, k in enumerate(keys):
+        v = x[:, i + num_metrics_columns]
+        mu = v[j]  # best single result
+        plt.subplot(n, n, i + 1)
+        plt_color_scatter(v, fitness, cmap='viridis', alpha=.8, edgecolors='none')
+        plt.plot(mu, fitness.max(), 'k+', markersize=15)
+        plt.title(f'{k} = {mu:.3g}', fontdict={'size': 9})  # limit to 40 characters
+        plt.tick_params(axis='both', labelsize=8)  # Set axis label size to 8
+        if i % n != 0:
+            plt.yticks([])
+
+    file = csv_file.with_name('tune_scatter_plots.png')  # filename
+    plt.savefig(file, dpi=200)
+    plt.close()
+    LOGGER.info(f'Saved {file}')
+
+    # Fitness vs iteration
+    x = range(1, len(fitness) + 1)
+    plt.figure(figsize=(10, 6), tight_layout=True)
+    plt.plot(x, fitness, marker='o', linestyle='none', label='fitness')
+    plt.plot(x, gaussian_filter1d(fitness, sigma=3), ':', label='smoothed', linewidth=2)  # smoothing line
+    plt.title('Fitness vs Iteration')
+    plt.xlabel('Iteration')
+    plt.ylabel('Fitness')
+    plt.grid(True)
+    plt.legend()
+
+    file = csv_file.with_name('tune_fitness.png')  # filename
+    plt.savefig(file, dpi=200)
+    plt.close()
+    LOGGER.info(f'Saved {file}')
+
+
 def output_to_target(output, max_det=300):
     """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting."""
     targets = []
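`plot_tune_results` only needs a CSV whose first column is fitness and whose remaining columns are hyperparameters, so it can be exercised without running the tuner. A hedged sketch on a small synthetic CSV (column names and values are illustrative):

```python
import numpy as np
import pandas as pd

from ultralytics.utils.plotting import plot_tune_results

# Build a small synthetic tune_results.csv: column 0 is fitness, remaining columns are hyperparameters.
rng = np.random.default_rng(0)
df = pd.DataFrame({
    'fitness': rng.random(25),
    'lr0': 10 ** rng.uniform(-5, -1, 25),
    'momentum': rng.uniform(0.7, 0.98, 25)})
df.to_csv('tune_results.csv', index=False)

plot_tune_results('tune_results.csv')  # writes tune_scatter_plots.png and tune_fitness.png next to the CSV
```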
diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py
index def7442b..dcab101c 100644
--- a/ultralytics/utils/torch_utils.py
+++ b/ultralytics/utils/torch_utils.py
@@ -395,12 +395,6 @@ def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
         strip_optimizer(f)
         ```
     """
-    # Use dill (if exists) to serialize the lambda functions where pickle does not do this
-    try:
-        import dill as pickle
-    except ImportError:
-        import pickle
-
     x = torch.load(f, map_location=torch.device('cpu'))
     if 'model' not in x:
         LOGGER.info(f'Skipping {f}, not a valid Ultralytics model.')
@@ -419,8 +413,8 @@ def strip_optimizer(f: Union[str, Path] = 'best.pt', s: str = '') -> None:
         p.requires_grad = False
     x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # strip non-default keys
     # x['model'].args = x['train_args']
-    torch.save(x, s or f, pickle_module=pickle)
-    mb = os.path.getsize(s or f) / 1E6  # filesize
+    torch.save(x, s or f)
+    mb = os.path.getsize(s or f) / 1E6  # file size
     LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
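With the dill fallback removed, checkpoints are serialized with the standard pickle module again, and `strip_optimizer` is used exactly as its docstring shows. A brief usage sketch (the weights path is hypothetical):

```python
from pathlib import Path

from ultralytics.utils.torch_utils import strip_optimizer

# Strip optimizer state from finished checkpoints to shrink them; each file is overwritten in place
# and the new size is logged (directory below is illustrative).
for f in Path('runs/detect/train/weights').glob('*.pt'):
    strip_optimizer(f)
```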