From 451cf8b6473ebd3363f74e5321203889ec9c6261 Mon Sep 17 00:00:00 2001
From: Glenn Jocher <glenn.jocher@ultralytics.com>
Date: Sun, 4 Jun 2023 22:35:50 +0200
Subject: [PATCH] Add Adamax, NAdam, RAdam optimizers (#2969)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
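
A minimal usage sketch of the new options (assuming the standard ultralytics Python
and CLI entry points shown in the docs below; epochs/lr0 values are illustrative only):

    from ultralytics import YOLO

    model = YOLO('yolov8n.pt')

    # new default 'auto' lets the trainer pick SGD or NAdam from the estimated iteration count
    model.train(data='coco128.yaml', epochs=100, optimizer='auto')

    # any of the added optimizers can still be requested explicitly
    model.train(data='coco128.yaml', epochs=100, optimizer='RAdam', lr0=0.001)

    # CLI equivalent: yolo mode=train model=yolov8n.pt data=coco128.yaml optimizer=NAdam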
---
 docs/modes/train.md                | 94 +++++++++++++++---------------
 docs/usage/cfg.md                  | 94 +++++++++++++++---------------
 ultralytics/yolo/cfg/default.yaml  |  2 +-
 ultralytics/yolo/engine/trainer.py | 86 +++++++++++++++------------
 ultralytics/yolo/utils/tuner.py    |  2 +-
 5 files changed, 144 insertions(+), 134 deletions(-)

diff --git a/docs/modes/train.md b/docs/modes/train.md
index d5603605..44923733 100644
--- a/docs/modes/train.md
+++ b/docs/modes/train.md
@@ -55,50 +55,50 @@ include the choice of optimizer, the choice of loss function, and the size and c
 is important to carefully tune and experiment with these settings to achieve the best possible performance for a given
 task.
 
-| Key               | Value    | Description                                                                 |
-|-------------------|----------|-----------------------------------------------------------------------------|
-| `model`           | `None`   | path to model file, i.e. yolov8n.pt, yolov8n.yaml                           |
-| `data`            | `None`   | path to data file, i.e. coco128.yaml                                        |
-| `epochs`          | `100`    | number of epochs to train for                                               |
-| `patience`        | `50`     | epochs to wait for no observable improvement for early stopping of training |
-| `batch`           | `16`     | number of images per batch (-1 for AutoBatch)                               |
-| `imgsz`           | `640`    | size of input images as integer or w,h                                      |
-| `save`            | `True`   | save train checkpoints and predict results                                  |
-| `save_period`     | `-1`     | Save checkpoint every x epochs (disabled if < 1)                            |
-| `cache`           | `False`  | True/ram, disk or False. Use cache for data loading                         |
-| `device`          | `None`   | device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu        |
-| `workers`         | `8`      | number of worker threads for data loading (per RANK if DDP)                 |
-| `project`         | `None`   | project name                                                                |
-| `name`            | `None`   | experiment name                                                             |
-| `exist_ok`        | `False`  | whether to overwrite existing experiment                                    |
-| `pretrained`      | `False`  | whether to use a pretrained model                                           |
-| `optimizer`       | `'SGD'`  | optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp']               |
-| `verbose`         | `False`  | whether to print verbose output                                             |
-| `seed`            | `0`      | random seed for reproducibility                                             |
-| `deterministic`   | `True`   | whether to enable deterministic mode                                        |
-| `single_cls`      | `False`  | train multi-class data as single-class                                      |
-| `rect`            | `False`  | rectangular training with each batch collated for minimum padding           |
-| `cos_lr`          | `False`  | use cosine learning rate scheduler                                          |
-| `close_mosaic`    | `0`      | (int) disable mosaic augmentation for final epochs                          |
-| `resume`          | `False`  | resume training from last checkpoint                                        |
-| `amp`             | `True`   | Automatic Mixed Precision (AMP) training, choices=[True, False]             |
-| `fraction`        | `1.0`    | dataset fraction to train on (default is 1.0, all images in train set)      |
-| `profile`         | `False`  | profile ONNX and TensorRT speeds during training for loggers                |
-| `lr0`             | `0.01`   | initial learning rate (i.e. SGD=1E-2, Adam=1E-3)                            |
-| `lrf`             | `0.01`   | final learning rate (lr0 * lrf)                                             |
-| `momentum`        | `0.937`  | SGD momentum/Adam beta1                                                     |
-| `weight_decay`    | `0.0005` | optimizer weight decay 5e-4                                                 |
-| `warmup_epochs`   | `3.0`    | warmup epochs (fractions ok)                                                |
-| `warmup_momentum` | `0.8`    | warmup initial momentum                                                     |
-| `warmup_bias_lr`  | `0.1`    | warmup initial bias lr                                                      |
-| `box`             | `7.5`    | box loss gain                                                               |
-| `cls`             | `0.5`    | cls loss gain (scale with pixels)                                           |
-| `dfl`             | `1.5`    | dfl loss gain                                                               |
-| `pose`            | `12.0`   | pose loss gain (pose-only)                                                  |
-| `kobj`            | `2.0`    | keypoint obj loss gain (pose-only)                                          |
-| `label_smoothing` | `0.0`    | label smoothing (fraction)                                                  |
-| `nbs`             | `64`     | nominal batch size                                                          |
-| `overlap_mask`    | `True`   | masks should overlap during training (segment train only)                   |
-| `mask_ratio`      | `4`      | mask downsample ratio (segment train only)                                  |
-| `dropout`         | `0.0`    | use dropout regularization (classify train only)                            |
-| `val`             | `True`   | validate/test during training                                               |
+| Key               | Value    | Description                                                                       |
+|-------------------|----------|-----------------------------------------------------------------------------------|
+| `model`           | `None`   | path to model file, i.e. yolov8n.pt, yolov8n.yaml                                 |
+| `data`            | `None`   | path to data file, i.e. coco128.yaml                                              |
+| `epochs`          | `100`    | number of epochs to train for                                                     |
+| `patience`        | `50`     | epochs to wait for no observable improvement for early stopping of training       |
+| `batch`           | `16`     | number of images per batch (-1 for AutoBatch)                                     |
+| `imgsz`           | `640`    | size of input images as integer or w,h                                            |
+| `save`            | `True`   | save train checkpoints and predict results                                        |
+| `save_period`     | `-1`     | Save checkpoint every x epochs (disabled if < 1)                                  |
+| `cache`           | `False`  | True/ram, disk or False. Use cache for data loading                               |
+| `device`          | `None`   | device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu              |
+| `workers`         | `8`      | number of worker threads for data loading (per RANK if DDP)                       |
+| `project`         | `None`   | project name                                                                      |
+| `name`            | `None`   | experiment name                                                                   |
+| `exist_ok`        | `False`  | whether to overwrite existing experiment                                          |
+| `pretrained`      | `False`  | whether to use a pretrained model                                                 |
+| `optimizer`       | `'auto'` | optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto] |
+| `verbose`         | `False`  | whether to print verbose output                                                   |
+| `seed`            | `0`      | random seed for reproducibility                                                   |
+| `deterministic`   | `True`   | whether to enable deterministic mode                                              |
+| `single_cls`      | `False`  | train multi-class data as single-class                                            |
+| `rect`            | `False`  | rectangular training with each batch collated for minimum padding                 |
+| `cos_lr`          | `False`  | use cosine learning rate scheduler                                                |
+| `close_mosaic`    | `0`      | (int) disable mosaic augmentation for final epochs                                |
+| `resume`          | `False`  | resume training from last checkpoint                                              |
+| `amp`             | `True`   | Automatic Mixed Precision (AMP) training, choices=[True, False]                   |
+| `fraction`        | `1.0`    | dataset fraction to train on (default is 1.0, all images in train set)            |
+| `profile`         | `False`  | profile ONNX and TensorRT speeds during training for loggers                      |
+| `lr0`             | `0.01`   | initial learning rate (i.e. SGD=1E-2, Adam=1E-3)                                  |
+| `lrf`             | `0.01`   | final learning rate (lr0 * lrf)                                                   |
+| `momentum`        | `0.937`  | SGD momentum/Adam beta1                                                           |
+| `weight_decay`    | `0.0005` | optimizer weight decay 5e-4                                                       |
+| `warmup_epochs`   | `3.0`    | warmup epochs (fractions ok)                                                      |
+| `warmup_momentum` | `0.8`    | warmup initial momentum                                                           |
+| `warmup_bias_lr`  | `0.1`    | warmup initial bias lr                                                            |
+| `box`             | `7.5`    | box loss gain                                                                     |
+| `cls`             | `0.5`    | cls loss gain (scale with pixels)                                                 |
+| `dfl`             | `1.5`    | dfl loss gain                                                                     |
+| `pose`            | `12.0`   | pose loss gain (pose-only)                                                        |
+| `kobj`            | `2.0`    | keypoint obj loss gain (pose-only)                                                |
+| `label_smoothing` | `0.0`    | label smoothing (fraction)                                                        |
+| `nbs`             | `64`     | nominal batch size                                                                |
+| `overlap_mask`    | `True`   | masks should overlap during training (segment train only)                         |
+| `mask_ratio`      | `4`      | mask downsample ratio (segment train only)                                        |
+| `dropout`         | `0.0`    | use dropout regularization (classify train only)                                  |
+| `val`             | `True`   | validate/test during training                                                     |
diff --git a/docs/usage/cfg.md b/docs/usage/cfg.md
index 0c990025..b7ddb496 100644
--- a/docs/usage/cfg.md
+++ b/docs/usage/cfg.md
@@ -77,53 +77,53 @@ include:
 
 The training settings for YOLO models encompass various hyperparameters and configurations used during the training process. These settings influence the model's performance, speed, and accuracy. Key training settings include batch size, learning rate, momentum, and weight decay. Additionally, the choice of optimizer, loss function, and training dataset composition can impact the training process. Careful tuning and experimentation with these settings are crucial for optimizing performance.
 
-| Key               | Value    | Description                                                                 |
-|-------------------|----------|-----------------------------------------------------------------------------|
-| `model`           | `None`   | path to model file, i.e. yolov8n.pt, yolov8n.yaml                           |
-| `data`            | `None`   | path to data file, i.e. coco128.yaml                                        |
-| `epochs`          | `100`    | number of epochs to train for                                               |
-| `patience`        | `50`     | epochs to wait for no observable improvement for early stopping of training |
-| `batch`           | `16`     | number of images per batch (-1 for AutoBatch)                               |
-| `imgsz`           | `640`    | size of input images as integer or w,h                                      |
-| `save`            | `True`   | save train checkpoints and predict results                                  |
-| `save_period`     | `-1`     | Save checkpoint every x epochs (disabled if < 1)                            |
-| `cache`           | `False`  | True/ram, disk or False. Use cache for data loading                         |
-| `device`          | `None`   | device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu        |
-| `workers`         | `8`      | number of worker threads for data loading (per RANK if DDP)                 |
-| `project`         | `None`   | project name                                                                |
-| `name`            | `None`   | experiment name                                                             |
-| `exist_ok`        | `False`  | whether to overwrite existing experiment                                    |
-| `pretrained`      | `False`  | whether to use a pretrained model                                           |
-| `optimizer`       | `'SGD'`  | optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp']               |
-| `verbose`         | `False`  | whether to print verbose output                                             |
-| `seed`            | `0`      | random seed for reproducibility                                             |
-| `deterministic`   | `True`   | whether to enable deterministic mode                                        |
-| `single_cls`      | `False`  | train multi-class data as single-class                                      |
-| `rect`            | `False`  | rectangular training with each batch collated for minimum padding           |
-| `cos_lr`          | `False`  | use cosine learning rate scheduler                                          |
-| `close_mosaic`    | `0`      | (int) disable mosaic augmentation for final epochs                          |
-| `resume`          | `False`  | resume training from last checkpoint                                        |
-| `amp`             | `True`   | Automatic Mixed Precision (AMP) training, choices=[True, False]             |
-| `fraction`        | `1.0`    | dataset fraction to train on (default is 1.0, all images in train set)      |
-| `profile`         | `False`  | profile ONNX and TensorRT speeds during training for loggers                |
-| `lr0`             | `0.01`   | initial learning rate (i.e. SGD=1E-2, Adam=1E-3)                            |
-| `lrf`             | `0.01`   | final learning rate (lr0 * lrf)                                             |
-| `momentum`        | `0.937`  | SGD momentum/Adam beta1                                                     |
-| `weight_decay`    | `0.0005` | optimizer weight decay 5e-4                                                 |
-| `warmup_epochs`   | `3.0`    | warmup epochs (fractions ok)                                                |
-| `warmup_momentum` | `0.8`    | warmup initial momentum                                                     |
-| `warmup_bias_lr`  | `0.1`    | warmup initial bias lr                                                      |
-| `box`             | `7.5`    | box loss gain                                                               |
-| `cls`             | `0.5`    | cls loss gain (scale with pixels)                                           |
-| `dfl`             | `1.5`    | dfl loss gain                                                               |
-| `pose`            | `12.0`   | pose loss gain (pose-only)                                                  |
-| `kobj`            | `2.0`    | keypoint obj loss gain (pose-only)                                          |
-| `label_smoothing` | `0.0`    | label smoothing (fraction)                                                  |
-| `nbs`             | `64`     | nominal batch size                                                          |
-| `overlap_mask`    | `True`   | masks should overlap during training (segment train only)                   |
-| `mask_ratio`      | `4`      | mask downsample ratio (segment train only)                                  |
-| `dropout`         | `0.0`    | use dropout regularization (classify train only)                            |
-| `val`             | `True`   | validate/test during training                                               |
+| Key               | Value    | Description                                                                       |
+|-------------------|----------|-----------------------------------------------------------------------------------|
+| `model`           | `None`   | path to model file, i.e. yolov8n.pt, yolov8n.yaml                                 |
+| `data`            | `None`   | path to data file, i.e. coco128.yaml                                              |
+| `epochs`          | `100`    | number of epochs to train for                                                     |
+| `patience`        | `50`     | epochs to wait for no observable improvement for early stopping of training       |
+| `batch`           | `16`     | number of images per batch (-1 for AutoBatch)                                     |
+| `imgsz`           | `640`    | size of input images as integer or w,h                                            |
+| `save`            | `True`   | save train checkpoints and predict results                                        |
+| `save_period`     | `-1`     | Save checkpoint every x epochs (disabled if < 1)                                  |
+| `cache`           | `False`  | True/ram, disk or False. Use cache for data loading                               |
+| `device`          | `None`   | device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu              |
+| `workers`         | `8`      | number of worker threads for data loading (per RANK if DDP)                       |
+| `project`         | `None`   | project name                                                                      |
+| `name`            | `None`   | experiment name                                                                   |
+| `exist_ok`        | `False`  | whether to overwrite existing experiment                                          |
+| `pretrained`      | `False`  | whether to use a pretrained model                                                 |
+| `optimizer`       | `'auto'` | optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto] |
+| `verbose`         | `False`  | whether to print verbose output                                                   |
+| `seed`            | `0`      | random seed for reproducibility                                                   |
+| `deterministic`   | `True`   | whether to enable deterministic mode                                              |
+| `single_cls`      | `False`  | train multi-class data as single-class                                            |
+| `rect`            | `False`  | rectangular training with each batch collated for minimum padding                 |
+| `cos_lr`          | `False`  | use cosine learning rate scheduler                                                |
+| `close_mosaic`    | `0`      | (int) disable mosaic augmentation for final epochs                                |
+| `resume`          | `False`  | resume training from last checkpoint                                              |
+| `amp`             | `True`   | Automatic Mixed Precision (AMP) training, choices=[True, False]                   |
+| `fraction`        | `1.0`    | dataset fraction to train on (default is 1.0, all images in train set)            |
+| `profile`         | `False`  | profile ONNX and TensorRT speeds during training for loggers                      |
+| `lr0`             | `0.01`   | initial learning rate (i.e. SGD=1E-2, Adam=1E-3)                                  |
+| `lrf`             | `0.01`   | final learning rate (lr0 * lrf)                                                   |
+| `momentum`        | `0.937`  | SGD momentum/Adam beta1                                                           |
+| `weight_decay`    | `0.0005` | optimizer weight decay 5e-4                                                       |
+| `warmup_epochs`   | `3.0`    | warmup epochs (fractions ok)                                                      |
+| `warmup_momentum` | `0.8`    | warmup initial momentum                                                           |
+| `warmup_bias_lr`  | `0.1`    | warmup initial bias lr                                                            |
+| `box`             | `7.5`    | box loss gain                                                                     |
+| `cls`             | `0.5`    | cls loss gain (scale with pixels)                                                 |
+| `dfl`             | `1.5`    | dfl loss gain                                                                     |
+| `pose`            | `12.0`   | pose loss gain (pose-only)                                                        |
+| `kobj`            | `2.0`    | keypoint obj loss gain (pose-only)                                                |
+| `label_smoothing` | `0.0`    | label smoothing (fraction)                                                        |
+| `nbs`             | `64`     | nominal batch size                                                                |
+| `overlap_mask`    | `True`   | masks should overlap during training (segment train only)                         |
+| `mask_ratio`      | `4`      | mask downsample ratio (segment train only)                                        |
+| `dropout`         | `0.0`    | use dropout regularization (classify train only)                                  |
+| `val`             | `True`   | validate/test during training                                                     |
 
 [Train Guide](../modes/train.md){ .md-button .md-button--primary}
 
diff --git a/ultralytics/yolo/cfg/default.yaml b/ultralytics/yolo/cfg/default.yaml
index 41b94492..35be2e9d 100644
--- a/ultralytics/yolo/cfg/default.yaml
+++ b/ultralytics/yolo/cfg/default.yaml
@@ -20,7 +20,7 @@ project:  # project name
 name:  # experiment name, results saved to 'project/name' directory
 exist_ok: False  # whether to overwrite existing experiment
 pretrained: False  # whether to use a pretrained model
-optimizer: SGD  # optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
+optimizer: auto  # optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto]
 verbose: True  # whether to print verbose output
 seed: 0  # random seed for reproducibility
 deterministic: True  # whether to enable deterministic mode
diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py
index b39ad135..b0da4df5 100644
--- a/ultralytics/yolo/engine/trainer.py
+++ b/ultralytics/yolo/engine/trainer.py
@@ -5,6 +5,7 @@ Train a model on a dataset
 Usage:
     $ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16
 """
+import math
 import os
 import subprocess
 import time
@@ -14,11 +15,10 @@ from pathlib import Path
 
 import numpy as np
 import torch
-import torch.distributed as dist
-import torch.nn as nn
+from torch import distributed as dist
+from torch import nn, optim
 from torch.cuda import amp
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim import lr_scheduler
 from tqdm import tqdm
 
 from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
@@ -234,24 +234,8 @@ class BaseTrainer:
                 SyntaxError('batch=-1 to use AutoBatch is only available in Single-GPU training. '
                             'Please pass a valid batch size value for Multi-GPU DDP training, i.e. batch=16')
 
-        # Optimizer
-        self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
-        weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
-        self.optimizer = self.build_optimizer(model=self.model,
-                                              name=self.args.optimizer,
-                                              lr=self.args.lr0,
-                                              momentum=self.args.momentum,
-                                              decay=weight_decay)
-        # Scheduler
-        if self.args.cos_lr:
-            self.lf = one_cycle(1, self.args.lrf, self.epochs)  # cosine 1->hyp['lrf']
-        else:
-            self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear
-        self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
-        self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
-
         # Dataloaders
-        batch_size = self.batch_size // world_size if world_size > 1 else self.batch_size
+        batch_size = self.batch_size // max(world_size, 1)
         self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
         if RANK in (-1, 0):
             self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
@@ -261,6 +245,24 @@ class BaseTrainer:
             self.ema = ModelEMA(self.model)
             if self.args.plots and not self.args.v5loader:
                 self.plot_training_labels()
+
+        # Optimizer
+        self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
+        weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
+        iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
+        self.optimizer = self.build_optimizer(model=self.model,
+                                              name=self.args.optimizer,
+                                              lr=self.args.lr0,
+                                              momentum=self.args.momentum,
+                                              decay=weight_decay,
+                                              iterations=iterations)
+        # Scheduler
+        if self.args.cos_lr:
+            self.lf = one_cycle(1, self.args.lrf, self.epochs)  # cosine 1->hyp['lrf']
+        else:
+            self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear
+        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
+        self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
         self.resume_training(ckpt)
         self.scheduler.last_epoch = self.start_epoch - 1  # do not move
         self.run_callbacks('on_pretrain_routine_end')
@@ -603,24 +605,30 @@ class BaseTrainer:
             if hasattr(self.train_loader.dataset, 'close_mosaic'):
                 self.train_loader.dataset.close_mosaic(hyp=self.args)
 
-    @staticmethod
-    def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
+    def build_optimizer(self, model, name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
         """
-        Builds an optimizer with the specified parameters and parameter groups.
+        Constructs an optimizer for the given model, based on the specified optimizer name, learning rate,
+        momentum, weight decay, and number of iterations.
 
         Args:
-            model (nn.Module): model to optimize
-            name (str): name of the optimizer to use
-            lr (float): learning rate
-            momentum (float): momentum
-            decay (float): weight decay
+            model (torch.nn.Module): The model for which to build an optimizer.
+            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
+                based on the number of iterations. Default: 'auto'.
+            lr (float, optional): The learning rate for the optimizer. Default: 0.001.
+            momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
+            decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
+            iterations (float, optional): The number of iterations, which determines the optimizer if
+                name is 'auto'. Default: 1e5.
 
         Returns:
-            optimizer (torch.optim.Optimizer): the built optimizer
+            (torch.optim.Optimizer): The constructed optimizer.
         """
 
         g = [], [], []  # optimizer parameter groups
         bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
+        if name == 'auto':
+            name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 6000 else ('NAdam', 0.001, 0.9)
+            self.args.warmup_bias_lr = 0.0  # keep warmup_bias_lr low (no higher than 0.01) for NAdam
 
         for module_name, module in model.named_modules():
             for param_name, param in module.named_parameters(recurse=False):
@@ -632,19 +640,21 @@ class BaseTrainer:
                 else:  # weight (with decay)
                     g[0].append(param)
 
-        if name == 'Adam':
-            optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999))  # adjust beta1 to momentum
-        elif name == 'AdamW':
-            optimizer = torch.optim.AdamW(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
+        if name in ('Adam', 'Adamax', 'AdamW', 'NAdam', 'RAdam'):
+            optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
         elif name == 'RMSProp':
-            optimizer = torch.optim.RMSprop(g[2], lr=lr, momentum=momentum)
+            optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
         elif name == 'SGD':
-            optimizer = torch.optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
+            optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
         else:
-            raise NotImplementedError(f'Optimizer {name} not implemented.')
+            raise NotImplementedError(
+                f"Optimizer '{name}' not found in list of available optimizers "
+                '[Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, SGD, auto]. '
+                'To request support for additional optimizers please visit https://github.com/ultralytics/ultralytics.')
 
         optimizer.add_param_group({'params': g[0], 'weight_decay': decay})  # add g0 with weight_decay
         optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
-        LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
-                    f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias')
+        LOGGER.info(
+            f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
+            f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)')
         return optimizer
diff --git a/ultralytics/yolo/utils/tuner.py b/ultralytics/yolo/utils/tuner.py
index 54e1b010..9f57677a 100644
--- a/ultralytics/yolo/utils/tuner.py
+++ b/ultralytics/yolo/utils/tuner.py
@@ -14,7 +14,7 @@ except ImportError:
     tune = None
 
 default_space = {
-    # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'RMSProp']),
+    # 'optimizer': tune.choice(['SGD', 'Adam', 'Adamax', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
     'lr0': tune.uniform(1e-5, 1e-1),
     'lrf': tune.uniform(0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
     'momentum': tune.uniform(0.6, 0.98),  # SGD momentum/Adam beta1
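
A rough, self-contained sketch of the 'auto' selection rule added in build_optimizer
above; the iteration estimate mirrors the trainer code, auto_pick is a hypothetical
helper used only for illustration, and the dataset sizes are illustrative.

    import math

    def auto_pick(dataset_len, batch, epochs, nbs=64):
        # same estimate as the trainer: ceil(len(dataset) / max(batch, nbs)) * epochs
        iterations = math.ceil(dataset_len / max(batch, nbs)) * epochs
        # more than 6000 iterations -> SGD(lr=0.01), otherwise NAdam(lr=0.001)
        return ('SGD', 0.01) if iterations > 6000 else ('NAdam', 0.001)

    print(auto_pick(128, 16, 100))     # coco128: 2 * 100 = 200 iterations -> ('NAdam', 0.001)
    print(auto_pick(118287, 16, 100))  # COCO train2017: ~185k iterations -> ('SGD', 0.01)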