Add EMA and model checkpointing (#49)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent 27d6545117 · commit 4291b9c31c
@@ -9,6 +9,7 @@ Simple training loop; Boilerplate that could apply to any arbitrary neural network
 import os
 import time
 from collections import defaultdict
+from copy import deepcopy
 from datetime import datetime
 from pathlib import Path
 from typing import Dict, Union
@@ -29,6 +30,7 @@ from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT
 from ultralytics.yolo.utils.checks import print_args
 from ultralytics.yolo.utils.files import increment_path, save_yaml
 from ultralytics.yolo.utils.modeling import get_model
+from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel

 DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"

@@ -63,6 +65,7 @@ class BaseTrainer:
         self.trainset, self.testset = self.get_dataset(self.data)
         if self.args.model:
             self.model = self.get_model(self.args.model)
+        self.ema = None

         # epoch level metrics
         self.metrics = {}  # handle metrics returned by validator
@@ -144,6 +147,7 @@ class BaseTrainer:
            self.validator = self.get_validator()
            print("created testloader :", rank)
            self.console.info(self.progress_string())
+            self.ema = ModelEMA(self.model)

    def _do_train(self, rank=-1, world_size=1):
        if world_size > 1:
@@ -196,6 +200,7 @@ class BaseTrainer:
            if rank in [-1, 0]:
                # validation
                # callback: on_val_start()
+                self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
                self.validate()
                # callback: on_val_end()

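The `update_attr` call above copies inference metadata (class names, stride, args) onto the EMA model just before validation, so the EMA copy can be validated and serialized standalone. The commit relies on a `copy_attr` helper from ultralytics' torch utils whose body is not shown in this diff; a minimal sketch in the YOLOv5 style (an assumption — the real helper may differ) looks like:

def copy_attr(a, b, include=(), exclude=()):
    # Copy public attributes from b to a; if `include` is non-empty, copy only
    # those names, and always skip private names and anything in `exclude`.
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
            continue
        setattr(a, k, v)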
@@ -220,10 +225,10 @@ class BaseTrainer:
         ckpt = {
             'epoch': self.epoch,
             'best_fitness': self.best_fitness,
-            'model': None, # deepcopy(ema.ema).half(), # deepcopy(de_parallel(model)).half(),
-            'ema': None, # deepcopy(ema.ema).half(),
-            'updates': None, # ema.updates,
-            'optimizer': None, # optimizer.state_dict(),
+            'model': deepcopy(de_parallel(self.model)).half(),
+            'ema': deepcopy(self.ema.ema).half(),
+            'updates': self.ema.updates,
+            'optimizer': self.optimizer.state_dict(),
             'train_args': self.args,
             'date': datetime.now().isoformat()}

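With this hunk the checkpoint dict carries real state instead of `None` placeholders: half-precision copies of the raw and EMA weights, the EMA update counter, and the optimizer state. A hedged sketch of the matching resume path (not part of this commit; the name `resume_from_checkpoint` is illustrative):

import torch

def resume_from_checkpoint(path, model, optimizer, ema=None):
    # Illustrative resume helper keyed to the ckpt layout above.
    ckpt = torch.load(path, map_location='cpu')
    # Weights were saved with .half(); cast back to FP32 before training resumes.
    model.load_state_dict(ckpt['model'].float().state_dict())
    optimizer.load_state_dict(ckpt['optimizer'])
    if ema is not None and ckpt.get('ema') is not None:
        ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
        ema.updates = ckpt['updates']
    return ckpt['epoch'] + 1  # epoch to resume from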
@@ -266,6 +271,8 @@ class BaseTrainer:
         self.scaler.step(self.optimizer)
         self.scaler.update()
         self.optimizer.zero_grad()
+        if self.ema:
+            self.ema.update(self.model)

     def preprocess_batch(self, batch):
         """
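The EMA update is deliberately placed after `scaler.step`, `scaler.update`, and `zero_grad`, so the shadow weights track the freshly stepped parameters. A self-contained sketch of that ordering with a toy model (stand-in names; the real trainer wires this through its optimizer-step path):

from copy import deepcopy

import torch

model = torch.nn.Linear(10, 1)
ema_model = deepcopy(model).eval()  # frozen shadow copy
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
use_cuda = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

x, y = torch.randn(4, 10), torch.randn(4, 1)
with torch.cuda.amp.autocast(enabled=use_cuda):
    loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()
scaler.step(optimizer)  # unscales grads, then optimizer.step() unless they overflowed
scaler.update()
optimizer.zero_grad()

d = 0.999  # the real ModelEMA ramps this via decay * (1 - exp(-updates / tau))
with torch.no_grad():  # EMA reads the *updated* weights
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.mul_(d).add_(p.detach(), alpha=1 - d)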
@@ -30,19 +30,16 @@ class BaseValidator:
         Supports validation of a pre-trained model if passed or a model being trained
         if trainer is passed (trainer gets priority).
         """
-        training = trainer is not None
-        self.training = training
-        # trainer = trainer or self.trainer_class.get_trainer()
-        assert training or model is not None, "Either trainer or model is needed for validation"
-        if training:
-            model = trainer.model
+        self.training = trainer is not None
+        if self.training:
+            model = trainer.ema.ema or trainer.model
             self.args.half &= self.device.type != 'cpu'
             # NOTE: half() inference in evaluation will make training stuck,
             # so I comment it out for now, I think we can reuse half mode after we add EMA.
-            # model = model.half() if self.args.half else model
+            model = model.half() if self.args.half else model.float()
         else:  # TODO: handle this when detectMultiBackend is supported
+            assert model is not None, "Either trainer or model is needed for validation"
             # model = DetectMultiBacked(model)
-            pass
         # TODO: implement init_model_attributes()

         model.eval()
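The behavioral core of the validator change: when a trainer is passed, validation now runs against the EMA weights (`trainer.ema.ema or trainer.model`), and half precision is re-enabled with an explicit FP32 fallback instead of staying commented out. A condensed sketch of the selection logic (names mirror the diff; this is a restatement, not new API):

def pick_eval_model(trainer, model, device, half):
    # Prefer the EMA shadow weights when validating mid-training.
    if trainer is not None:
        model = trainer.ema.ema or trainer.model
        half &= device.type != 'cpu'  # half() only makes sense off-CPU
        model = model.half() if half else model.float()
    else:
        assert model is not None, "Either trainer or model is needed for validation"
    return model.eval()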
@@ -50,7 +47,7 @@ class BaseValidator:
         loss = 0
         n_batches = len(self.dataloader)
         desc = self.get_desc()
-        bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format=TQDM_BAR_FORMAT)
+        bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT)
         self.init_metrics(de_parallel(model))
         with torch.no_grad():
             for batch_i, batch in enumerate(bar):
@@ -67,7 +64,7 @@ class BaseValidator:

                 # loss
                 with dt[2]:
-                    if training:
+                    if self.training:
                         loss += trainer.criterion(preds, batch)[0]

                 # pre-process predictions
@@ -82,7 +79,7 @@ class BaseValidator:
         self.print_results()

         # print speeds
-        if not training:
+        if not self.training:
             t = tuple(x.t / len(self.dataloader.dataset.samples) * 1E3 for x in dt)  # speeds per image
             # shape = (self.dataloader.batch_size, 3, imgsz, imgsz)
             self.logger.info(
@@ -232,4 +232,4 @@ class ClassificationModel(BaseModel):
        elif nn.Conv2d in types:
            i = types.index(nn.Conv2d)  # nn.Conv2d index
            if m[i].out_channels != nc:
-                m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias)
+                m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
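The small fix here matters: `nn.Conv2d`'s `bias` argument is a bool flag, not a tensor, and passing the old layer's `bias` tensor raises at construction time because a multi-element tensor has no unambiguous truth value. A quick runnable demonstration:

import torch.nn as nn

old = nn.Conv2d(16, 1000, kernel_size=1)  # stand-in head with bias enabled
nc = 10  # new class count

try:
    # Buggy form: old.bias is a Parameter, evaluated as a bool inside the constructor.
    nn.Conv2d(old.in_channels, nc, old.kernel_size, old.stride, bias=old.bias)
except RuntimeError as e:
    print('buggy call failed:', e)

# Fixed form from the diff: pass the bool the constructor expects.
new = nn.Conv2d(old.in_channels, nc, old.kernel_size, old.stride, bias=old.bias is not None)
print(new.bias is not None)  # True — bias presence preserved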
@@ -192,3 +192,34 @@ def is_parallel(model):
 def de_parallel(model):
     # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
     return model.module if is_parallel(model) else model
+
+
+class ModelEMA:
+    """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
+    Keeps a moving average of everything in the model state_dict (parameters and buffers)
+    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+    """
+
+    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
+        # Create EMA
+        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
+        self.updates = updates  # number of EMA updates
+        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
+        for p in self.ema.parameters():
+            p.requires_grad_(False)
+
+    def update(self, model):
+        # Update EMA parameters
+        self.updates += 1
+        d = self.decay(self.updates)
+
+        msd = de_parallel(model).state_dict()  # model state_dict
+        for k, v in self.ema.state_dict().items():
+            if v.dtype.is_floating_point:  # true for FP16 and FP32
+                v *= d
+                v += (1 - d) * msd[k].detach()
+        # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
+
+    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
+        # Update EMA attributes
+        copy_attr(self.ema, model, include, exclude)
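The decay ramp `decay * (1 - exp(-x / tau))` is what makes the EMA usable from step one: early on the average leans almost entirely on the latest weights, and it only approaches the asymptotic 0.9999 after tens of thousands of updates. Plugging the defaults into the formula:

import math

decay, tau = 0.9999, 2000
d = lambda x: decay * (1 - math.exp(-x / tau))
for updates in (100, 2000, 10000, 100000):
    print(updates, round(d(updates), 4))
# 100 0.0488 | 2000 0.6321 | 10000 0.9932 | 100000 0.9999

Note also that `update` averages the full `state_dict`, so per the docstring the BatchNorm running statistics are smoothed along with the learnable parameters.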
@@ -159,11 +159,11 @@ class SegmentationTrainer(BaseTrainer):

        return tcls, tbox, indices, anch, tidxs, xywhn

-        if self.model.training:
+        if len(preds) == 2:  # eval
            p, proto, = preds
-        else:
-            p, proto, train_out = preds
-            p = train_out
+        else:  # len(3) train
+            _, proto, p = preds
+
        targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
        masks = batch["masks"]
        targets, masks = targets.to(self.device), masks.to(self.device).float()
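Dispatching on `len(preds)` instead of `self.model.training` decouples the loss from the module's train/eval flag, likely because the EMA copy used for validation is permanently in eval mode. The unpacking, restated with stand-in tuples:

# Stand-in tuples; the real preds are tensors from the segmentation head.
for preds in (('p', 'proto'), ('train_out', 'proto', 'p')):
    if len(preds) == 2:  # eval: (p, proto)
        p, proto = preds
    else:  # train: (train_out, proto, p)
        _, proto, p = preds
    print(p, proto)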
@@ -1,5 +1,4 @@
 import os
-from pathlib import Path

 import numpy as np
 import torch