Add Adamax, NAdam, RAdam optimizers (#2969)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Commit 451cf8b647 (parent f502b50365)
@@ -56,7 +56,7 @@ is important to carefully tune and experiment with these settings to achieve the
 task.
 
 | Key | Value | Description |
-|-------------------|----------|-----------------------------------------------------------------------------|
+|-------------------|----------|-----------------------------------------------------------------------------------|
 | `model` | `None` | path to model file, i.e. yolov8n.pt, yolov8n.yaml |
 | `data` | `None` | path to data file, i.e. coco128.yaml |
 | `epochs` | `100` | number of epochs to train for |
@@ -72,7 +72,7 @@ task.
 | `name` | `None` | experiment name |
 | `exist_ok` | `False` | whether to overwrite existing experiment |
 | `pretrained` | `False` | whether to use a pretrained model |
-| `optimizer` | `'SGD'` | optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp'] |
+| `optimizer` | `'auto'` | optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto] |
 | `verbose` | `False` | whether to print verbose output |
 | `seed` | `0` | random seed for reproducibility |
 | `deterministic` | `True` | whether to enable deterministic mode |
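
Note on the `optimizer` rows above (not part of the diff): with the expanded choice list, any of the new optimizers can be requested explicitly instead of relying on the new `auto` default. A minimal sketch, assuming the pre-existing ultralytics Python API (the `YOLO` class and the `optimizer` train argument are not introduced by this commit):

    from ultralytics import YOLO

    model = YOLO('yolov8n.pt')                                       # start from a pretrained checkpoint
    model.train(data='coco128.yaml', epochs=100, optimizer='NAdam')  # or 'Adamax', 'RAdam', 'AdamW', ... or 'auto'
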
@@ -78,7 +78,7 @@ include:
 The training settings for YOLO models encompass various hyperparameters and configurations used during the training process. These settings influence the model's performance, speed, and accuracy. Key training settings include batch size, learning rate, momentum, and weight decay. Additionally, the choice of optimizer, loss function, and training dataset composition can impact the training process. Careful tuning and experimentation with these settings are crucial for optimizing performance.
 
 | Key | Value | Description |
-|-------------------|----------|-----------------------------------------------------------------------------|
+|-------------------|----------|-----------------------------------------------------------------------------------|
 | `model` | `None` | path to model file, i.e. yolov8n.pt, yolov8n.yaml |
 | `data` | `None` | path to data file, i.e. coco128.yaml |
 | `epochs` | `100` | number of epochs to train for |
@@ -94,7 +94,7 @@ The training settings for YOLO models encompass various hyperparameters and conf
 | `name` | `None` | experiment name |
 | `exist_ok` | `False` | whether to overwrite existing experiment |
 | `pretrained` | `False` | whether to use a pretrained model |
-| `optimizer` | `'SGD'` | optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp'] |
+| `optimizer` | `'auto'` | optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto] |
 | `verbose` | `False` | whether to print verbose output |
 | `seed` | `0` | random seed for reproducibility |
 | `deterministic` | `True` | whether to enable deterministic mode |
@@ -20,7 +20,7 @@ project: # project name
 name: # experiment name, results saved to 'project/name' directory
 exist_ok: False # whether to overwrite existing experiment
 pretrained: False # whether to use a pretrained model
-optimizer: SGD # optimizer to use, choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
+optimizer: auto # optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto]
 verbose: True # whether to print verbose output
 seed: 0 # random seed for reproducibility
 deterministic: True # whether to enable deterministic mode
@@ -5,6 +5,7 @@ Train a model on a dataset
 Usage:
     $ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16
 """
+import math
 import os
 import subprocess
 import time
@@ -14,11 +15,10 @@ from pathlib import Path
 
 import numpy as np
 import torch
-import torch.distributed as dist
-import torch.nn as nn
+from torch import distributed as dist
+from torch import nn, optim
 from torch.cuda import amp
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim import lr_scheduler
 from tqdm import tqdm
 
 from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
@@ -234,24 +234,8 @@ class BaseTrainer:
                 SyntaxError('batch=-1 to use AutoBatch is only available in Single-GPU training. '
                             'Please pass a valid batch size value for Multi-GPU DDP training, i.e. batch=16')
 
-        # Optimizer
-        self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
-        weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
-        self.optimizer = self.build_optimizer(model=self.model,
-                                              name=self.args.optimizer,
-                                              lr=self.args.lr0,
-                                              momentum=self.args.momentum,
-                                              decay=weight_decay)
-        # Scheduler
-        if self.args.cos_lr:
-            self.lf = one_cycle(1, self.args.lrf, self.epochs)  # cosine 1->hyp['lrf']
-        else:
-            self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear
-        self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
-        self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
-
         # Dataloaders
-        batch_size = self.batch_size // world_size if world_size > 1 else self.batch_size
+        batch_size = self.batch_size // max(world_size, 1)
         self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode='train')
         if RANK in (-1, 0):
             self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1, mode='val')
@@ -261,6 +245,24 @@ class BaseTrainer:
             self.ema = ModelEMA(self.model)
             if self.args.plots and not self.args.v5loader:
                 self.plot_training_labels()
+
+        # Optimizer
+        self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
+        weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
+        iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
+        self.optimizer = self.build_optimizer(model=self.model,
+                                              name=self.args.optimizer,
+                                              lr=self.args.lr0,
+                                              momentum=self.args.momentum,
+                                              decay=weight_decay,
+                                              iterations=iterations)
+        # Scheduler
+        if self.args.cos_lr:
+            self.lf = one_cycle(1, self.args.lrf, self.epochs)  # cosine 1->hyp['lrf']
+        else:
+            self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear
+        self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
+        self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
         self.resume_training(ckpt)
         self.scheduler.last_epoch = self.start_epoch - 1  # do not move
         self.run_callbacks('on_pretrain_routine_end')
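
A quick arithmetic check of the scaling added above, with illustrative numbers only (batch 16, nominal batch size nbs=64, 100 epochs, a hypothetical 640-image training set, and the usual 0.0005 default weight decay assumed):

    import math

    batch_size, nbs, epochs, num_images = 16, 64, 100, 640             # assumed example values, not from this diff
    accumulate = max(round(nbs / batch_size), 1)                        # 4 -> step the optimizer every 4 batches
    weight_decay = 0.0005 * batch_size * accumulate / nbs               # 0.0005 -> unchanged when batch*accumulate == nbs
    iterations = math.ceil(num_images / max(batch_size, nbs)) * epochs  # 1000 estimated optimizer steps for the run

The `iterations` estimate is what `build_optimizer` uses below to pick between SGD and NAdam when `optimizer=auto`.
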
@@ -603,24 +605,30 @@ class BaseTrainer:
             if hasattr(self.train_loader.dataset, 'close_mosaic'):
                 self.train_loader.dataset.close_mosaic(hyp=self.args)
 
-    @staticmethod
-    def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
+    def build_optimizer(self, model, name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
         """
-        Builds an optimizer with the specified parameters and parameter groups.
+        Constructs an optimizer for the given model, based on the specified optimizer name, learning rate,
+        momentum, weight decay, and number of iterations.
 
         Args:
-            model (nn.Module): model to optimize
-            name (str): name of the optimizer to use
-            lr (float): learning rate
-            momentum (float): momentum
-            decay (float): weight decay
+            model (torch.nn.Module): The model for which to build an optimizer.
+            name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
+                based on the number of iterations. Default: 'auto'.
+            lr (float, optional): The learning rate for the optimizer. Default: 0.001.
+            momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
+            decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
+            iterations (float, optional): The number of iterations, which determines the optimizer if
+                name is 'auto'. Default: 1e5.
 
         Returns:
-            optimizer (torch.optim.Optimizer): the built optimizer
+            (torch.optim.Optimizer): The constructed optimizer.
         """
 
         g = [], [], []  # optimizer parameter groups
         bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
+        if name == 'auto':
+            name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 6000 else ('NAdam', 0.001, 0.9)
+            self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for NAdam
+
         for module_name, module in model.named_modules():
             for param_name, param in module.named_parameters(recurse=False):
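
The `auto` branch above reduces to a single threshold on the estimated iteration count. Restating it with the 1000-iteration example from the earlier sketch (illustrative only):

    iterations = 1000  # example value carried over from the earlier sketch
    name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 6000 else ('NAdam', 0.001, 0.9)
    # -> ('NAdam', 0.001, 0.9): short runs get NAdam with a small lr, longer runs fall back to SGD
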
@@ -632,19 +640,21 @@ class BaseTrainer:
                 else:  # weight (with decay)
                     g[0].append(param)
 
-        if name == 'Adam':
-            optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999))  # adjust beta1 to momentum
-        elif name == 'AdamW':
-            optimizer = torch.optim.AdamW(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
+        if name in ('Adam', 'Adamax', 'AdamW', 'NAdam', 'RAdam'):
+            optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
         elif name == 'RMSProp':
-            optimizer = torch.optim.RMSprop(g[2], lr=lr, momentum=momentum)
+            optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
         elif name == 'SGD':
-            optimizer = torch.optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
+            optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
         else:
-            raise NotImplementedError(f'Optimizer {name} not implemented.')
+            raise NotImplementedError(
+                f"Optimizer '{name}' not found in list of available optimizers "
+                f'[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto].'
+                'To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics.')
 
         optimizer.add_param_group({'params': g[0], 'weight_decay': decay})  # add g0 with weight_decay
         optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
-        LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
-                    f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias')
+        LOGGER.info(
+            f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
+            f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)')
         return optimizer
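
For readers unfamiliar with the parameter-group pattern used above, here is a self-contained sketch of the same construction on a toy model (plain torch only; the `nn.Sequential` model and the hard-coded hyperparameters are illustrative, and this is not the trainer's exact code path):

    import torch
    from torch import nn, optim

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Conv2d(8, 8, 3))
    g = [], [], []  # weights with decay, norm weights (no decay), biases (no decay)
    bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # BatchNorm2d, LayerNorm, ...
    for module in model.modules():
        for param_name, param in module.named_parameters(recurse=False):
            if param_name == 'bias':
                g[2].append(param)
            elif isinstance(module, bn):
                g[1].append(param)  # norm-layer weights are exempt from decay
            else:
                g[0].append(param)

    name, lr, momentum, decay = 'RAdam', 0.001, 0.9, 1e-5  # any Adam-family name resolves via getattr
    optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
    optimizer.add_param_group({'params': g[0], 'weight_decay': decay})
    optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})
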
@@ -14,7 +14,7 @@ except ImportError:
     tune = None
 
 default_space = {
-    # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'RMSProp']),
+    # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
     'lr0': tune.uniform(1e-5, 1e-1),
    'lrf': tune.uniform(0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
     'momentum': tune.uniform(0.6, 0.98),  # SGD momentum/Adam beta1
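
If the optimizer choice should itself be searched over, a hedged sketch of a custom space mirroring the commented-out entry above (assumes `ray[tune]` is installed and that this dict would be passed to the tuner in place of `default_space`):

    from ray import tune

    search_space = {
        'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
        'lr0': tune.uniform(1e-5, 1e-1),
        'momentum': tune.uniform(0.6, 0.98),
    }
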