Mirror of https://github.com/THU-MIG/yolov10.git (synced 2025-05-23 21:44:22 +08:00)
Add warmup and accumulation (#52)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in: parent 298287298d, commit d7df1770fa
@@ -10,6 +10,7 @@ from datetime import datetime
 from pathlib import Path
 from typing import Dict, Union
 
+import numpy as np
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
@@ -17,6 +18,7 @@ import torch.nn as nn
 from omegaconf import DictConfig, OmegaConf
 from torch.cuda import amp
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import lr_scheduler
 from tqdm import tqdm
 
 import ultralytics.yolo.utils as utils
@@ -26,7 +28,7 @@ from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT
 from ultralytics.yolo.utils.checks import print_args
 from ultralytics.yolo.utils.files import increment_path, save_yaml
 from ultralytics.yolo.utils.modeling import get_model
-from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel
+from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, one_cycle
 
 DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
 
@@ -63,6 +65,10 @@ class BaseTrainer:
         self.model = self.get_model(self.args.model)
         self.ema = None
 
+        # Optimization utils init
+        self.lf = None
+        self.scheduler = None
+
         # epoch level metrics
         self.metrics = {}  # handle metrics returned by validator
         self.best_fitness = None
@@ -131,12 +137,23 @@ class BaseTrainer:
         """
         Builds dataloaders and optimizer on correct rank process
         """
+        # Optimizer
         self.set_model_attributes()
+        accumulate = max(round(self.args.nbs / self.args.batch_size), 1)  # accumulate loss before optimizing
+        self.args.weight_decay *= self.args.batch_size * accumulate / self.args.nbs  # scale weight_decay
         self.optimizer = build_optimizer(model=self.model,
                                          name=self.args.optimizer,
                                          lr=self.args.lr0,
                                          momentum=self.args.momentum,
                                          decay=self.args.weight_decay)
+        # Scheduler
+        if self.args.cos_lr:
+            self.lf = one_cycle(1, self.args.lrf, self.args.epochs)  # cosine 1->hyp['lrf']
+        else:
+            self.lf = lambda x: (1 - x / self.args.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear
+        self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
+
+        # dataloaders
         self.train_loader = self.get_dataloader(self.trainset, batch_size=self.args.batch_size, rank=rank)
         if rank in {0, -1}:
             print(" Creating testloader rank :", rank)
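For reviewers, a small self-contained sketch of what the setup above computes. The numeric hyperparameters (batch_size=16, nbs=64, lr0=0.01, lrf=0.01, epochs=100, weight_decay=0.0005) are illustrative assumptions, not values fixed by this commit:

import math

# Illustrative values only (assumed defaults, not taken from this commit):
batch_size, nbs = 16, 64            # actual vs. nominal batch size
lr0, lrf, epochs = 0.01, 0.01, 100
weight_decay = 0.0005

# Gradient accumulation: step the optimizer every `accumulate` batches so the
# effective batch size approaches the nominal one (nbs).
accumulate = max(round(nbs / batch_size), 1)        # -> 4
weight_decay *= batch_size * accumulate / nbs       # 16 * 4 / 64 == 1, so unchanged here

# The two LR multipliers selected by cos_lr, applied on top of lr0 by LambdaLR.
cosine = lambda x: ((1 - math.cos(x * math.pi / epochs)) / 2) * (lrf - 1) + 1   # one_cycle(1, lrf, epochs)
linear = lambda x: (1 - x / epochs) * (1.0 - lrf) + lrf

for e in (0, 50, 99):
    print(f"epoch {e}: cosine lr {lr0 * cosine(e):.6f}, linear lr {lr0 * linear(e):.6f}")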
@@ -154,10 +171,13 @@ class BaseTrainer:
         self.trigger_callbacks("before_train")
         self._setup_train(rank)
 
-        self.epoch = 1
+        self.epoch = 0
         self.epoch_time = None
         self.epoch_time_start = time.time()
         self.train_time_start = time.time()
+        nb = len(self.train_loader)  # number of batches
+        nw = max(round(self.args.warmup_epochs * nb), 100)  # number of warmup iterations
+        last_opt_step = -1
         for epoch in range(self.args.epochs):
             self.trigger_callbacks("on_epoch_start")
             self.model.train()
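A quick worked example of the warmup length computed above; the dataset size, batch size, and warmup_epochs value are assumed:

# Assumed example: 1,280 training images at batch size 16, warmup_epochs = 3.0
nb = 1280 // 16                        # 80 batches per epoch
nw = max(round(3.0 * nb), 100)         # 240 warmup iterations (never fewer than 100)
print(nb, nw)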
@@ -170,7 +190,18 @@ class BaseTrainer:
                 # forward
                 batch = self.preprocess_batch(batch)
 
-                # TODO: warmup, multiscale
+                # warmup
+                ni = i + nb * epoch
+                if ni <= nw:
+                    xi = [0, nw]  # x interp
+                    accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.args.batch_size]).round())
+                    for j, x in enumerate(self.optimizer.param_groups):
+                        # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
+                        x['lr'] = np.interp(
+                            ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x['initial_lr'] * self.lf(epoch)])
+                        if 'momentum' in x:
+                            x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
+
                 preds = self.model(batch["img"])
                 self.loss, self.loss_items = self.criterion(preds, batch)
                 self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
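A standalone sketch of the warmup ramp in isolation from the trainer. The param-group layout (biases in group 0) and the hyperparameter values are assumptions for illustration:

import numpy as np

# Assumed hyperparameters for illustration (typical YOLO-style defaults):
warmup_bias_lr, warmup_momentum, momentum, lr0 = 0.1, 0.8, 0.937, 0.01
nw = 240                               # warmup iterations
lf = lambda epoch: 1.0                 # flat post-warmup schedule, enough for the demo

xi = [0, nw]                           # interpolation endpoints
for ni in (0, 60, 120, 240):           # global batch index
    for j in range(2):                 # pretend param groups: group 0 holds the biases
        start = warmup_bias_lr if j == 0 else 0.0
        lr = np.interp(ni, xi, [start, lr0 * lf(0)])
        mom = np.interp(ni, xi, [warmup_momentum, momentum])
        print(f"ni={ni:3d} group={j} lr={lr:.4f} momentum={mom:.3f}")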
@@ -181,7 +212,9 @@ class BaseTrainer:
                 self.scaler.scale(self.loss).backward()
 
                 # optimize
-                self.optimizer_step()
+                if ni - last_opt_step >= accumulate:
+                    self.optimizer_step()
+                    last_opt_step = ni
 
                 # log
                 mem = (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
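A minimal sketch of the step-gating pattern used here, stripped of the model and scaler; it shows only the arithmetic that decides when the optimizer fires:

# The gating arithmetic only; in the trainer, backward() still runs every batch
# and gradients keep accumulating until the optimizer step fires.
accumulate, last_opt_step = 4, -1
steps_taken = []
for ni in range(12):                   # global batch counter
    if ni - last_opt_step >= accumulate:
        steps_taken.append(ni)         # optimizer_step() would run here
        last_opt_step = ni
print(steps_taken)                     # [3, 7, 11] -> one step per 4 batches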
@@ -27,6 +27,7 @@ single_cls: False  # train multi-class data as single-class
 image_weights: False  # use weighted image selection for training
 shuffle: True
 rect: False  # support rectangular training
+cos_lr: False  # Use cosine LR scheduler
 overlap_mask: True  # Segmentation masks overlap
 mask_ratio: 4  # Segmentation mask downsample ratio
 
@@ -71,6 +72,7 @@ mosaic: 1.0  # image mosaic (probability)
 mixup: 0.0  # image mixup (probability)
 copy_paste: 0.0  # segment copy-paste (probability)
 label_smoothing: 0.0
+nbs: 64  # nominal batch size
 # anchors: 3
 
 # Hydra configs --------------------------------------------------------------------------------------------------------
@@ -194,6 +194,11 @@ def de_parallel(model):
     return model.module if is_parallel(model) else model
 
 
+def one_cycle(y1=0.0, y2=1.0, steps=100):
+    # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
+    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
+
+
 class ModelEMA:
     """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
     Keeps a moving average of everything in the model state_dict (parameters and buffers)
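For context, a small usage sketch of one_cycle with LambdaLR; the optimizer and the lr0/lrf/epochs values are assumptions, and the helper is repeated here only so the snippet runs on its own:

import math
import torch
from torch.optim import lr_scheduler

def one_cycle(y1=0.0, y2=1.0, steps=100):
    # same sinusoidal ramp as the helper added above
    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.01)                            # lr0
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=one_cycle(1, 0.01, 100))

for epoch in range(100):
    # ... train one epoch ...
    scheduler.step()

print(optimizer.param_groups[0]['lr'])   # ~0.0001 == lr0 * lrf after the full cosine decay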