mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-07-07 13:44:23 +08:00
Add AutoBatch from YOLOv5 (#145)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
172cef2d20
commit
99275814f1
@ -27,6 +27,7 @@ from ultralytics.yolo.configs import get_config
|
|||||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||||
from ultralytics.yolo.utils import (DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr,
|
from ultralytics.yolo.utils import (DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr,
|
||||||
yaml_save)
|
yaml_save)
|
||||||
|
from ultralytics.yolo.utils.autobatch import check_train_batch_size
|
||||||
from ultralytics.yolo.utils.checks import check_file, print_args
|
from ultralytics.yolo.utils.checks import check_file, print_args
|
||||||
from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
|
from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
|
||||||
from ultralytics.yolo.utils.files import get_latest_run, increment_path
|
from ultralytics.yolo.utils.files import get_latest_run, increment_path
|
||||||
@ -135,7 +136,7 @@ class BaseTrainer:
|
|||||||
self.fitness = None
|
self.fitness = None
|
||||||
self.loss = None
|
self.loss = None
|
||||||
self.tloss = None
|
self.tloss = None
|
||||||
self.loss_names = None
|
self.loss_names = ['Loss']
|
||||||
self.csv = self.save_dir / 'results.csv'
|
self.csv = self.save_dir / 'results.csv'
|
||||||
self.plot_idx = [0, 1, 2]
|
self.plot_idx = [0, 1, 2]
|
||||||
|
|
||||||
@ -192,6 +193,15 @@ class BaseTrainer:
|
|||||||
self.set_model_attributes()
|
self.set_model_attributes()
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
self.model = DDP(self.model, device_ids=[rank])
|
self.model = DDP(self.model, device_ids=[rank])
|
||||||
|
|
||||||
|
# Batch size
|
||||||
|
if self.batch_size == -1:
|
||||||
|
if RANK == -1: # single-GPU only, estimate best batch size
|
||||||
|
self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)
|
||||||
|
else:
|
||||||
|
SyntaxError('batch=-1 to use AutoBatch is only available in Single-GPU training. '
|
||||||
|
'Please pass a valid batch size value for Multi-GPU DDP training, i.e. batch=16')
|
||||||
|
|
||||||
# Optimizer
|
# Optimizer
|
||||||
self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing
|
self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing
|
||||||
self.args.weight_decay *= self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
|
self.args.weight_decay *= self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
|
||||||
|
@ -78,7 +78,7 @@ class BaseValidator:
|
|||||||
self.device = trainer.device
|
self.device = trainer.device
|
||||||
self.data = trainer.data
|
self.data = trainer.data
|
||||||
model = trainer.ema.ema or trainer.model
|
model = trainer.ema.ema or trainer.model
|
||||||
self.args.half &= self.device.type != 'cpu'
|
self.args.half = self.device.type != 'cpu' # force FP16 val during training
|
||||||
model = model.half() if self.args.half else model.float()
|
model = model.half() if self.args.half else model.float()
|
||||||
self.model = model
|
self.model = model
|
||||||
self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
|
self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
|
||||||
|
72
ultralytics/yolo/utils/autobatch.py
Normal file
72
ultralytics/yolo/utils/autobatch.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
||||||
|
"""
|
||||||
|
Auto-batch utils
|
||||||
|
"""
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ultralytics.yolo.utils import LOGGER, colorstr
|
||||||
|
from ultralytics.yolo.utils.torch_utils import profile
|
||||||
|
|
||||||
|
|
||||||
|
def check_train_batch_size(model, imgsz=640, amp=True):
|
||||||
|
# Check YOLOv5 training batch size
|
||||||
|
with torch.cuda.amp.autocast(amp):
|
||||||
|
return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
|
||||||
|
|
||||||
|
|
||||||
|
def autobatch(model, imgsz=640, fraction=0.7, batch_size=16):
|
||||||
|
# Automatically estimate best YOLOv5 batch size to use `fraction` of available CUDA memory
|
||||||
|
# Usage:
|
||||||
|
# import torch
|
||||||
|
# from utils.autobatch import autobatch
|
||||||
|
# model = torch.hub.load('ultralytics/yolov5', 'yolov5s', autoshape=False)
|
||||||
|
# print(autobatch(model))
|
||||||
|
|
||||||
|
# Check device
|
||||||
|
prefix = colorstr('AutoBatch: ')
|
||||||
|
LOGGER.info(f'{prefix}Computing optimal batch size for --imgsz {imgsz}')
|
||||||
|
device = next(model.parameters()).device # get model device
|
||||||
|
if device.type == 'cpu':
|
||||||
|
LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
|
||||||
|
return batch_size
|
||||||
|
if torch.backends.cudnn.benchmark:
|
||||||
|
LOGGER.info(f'{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}')
|
||||||
|
return batch_size
|
||||||
|
|
||||||
|
# Inspect CUDA memory
|
||||||
|
gb = 1 << 30 # bytes to GiB (1024 ** 3)
|
||||||
|
d = str(device).upper() # 'CUDA:0'
|
||||||
|
properties = torch.cuda.get_device_properties(device) # device properties
|
||||||
|
t = properties.total_memory / gb # GiB total
|
||||||
|
r = torch.cuda.memory_reserved(device) / gb # GiB reserved
|
||||||
|
a = torch.cuda.memory_allocated(device) / gb # GiB allocated
|
||||||
|
f = t - (r + a) # GiB free
|
||||||
|
LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free')
|
||||||
|
|
||||||
|
# Profile batch sizes
|
||||||
|
batch_sizes = [1, 2, 4, 8, 16]
|
||||||
|
try:
|
||||||
|
img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
|
||||||
|
results = profile(img, model, n=3, device=device)
|
||||||
|
except Exception as e:
|
||||||
|
LOGGER.warning(f'{prefix}{e}')
|
||||||
|
|
||||||
|
# Fit a solution
|
||||||
|
y = [x[2] for x in results if x] # memory [2]
|
||||||
|
p = np.polyfit(batch_sizes[:len(y)], y, deg=1) # first degree polynomial fit
|
||||||
|
b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size)
|
||||||
|
if None in results: # some sizes failed
|
||||||
|
i = results.index(None) # first fail index
|
||||||
|
if b >= batch_sizes[i]: # y intercept above failure point
|
||||||
|
b = batch_sizes[max(i - 1, 0)] # select prior safe point
|
||||||
|
if b < 1 or b > 1024: # b outside of safe range
|
||||||
|
b = batch_size
|
||||||
|
LOGGER.warning(f'{prefix}WARNING ⚠️ CUDA anomaly detected, recommend restart environment and retry command.')
|
||||||
|
|
||||||
|
fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted
|
||||||
|
LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅')
|
||||||
|
return b
|
@ -299,3 +299,54 @@ def guess_task_from_head(head):
|
|||||||
raise SyntaxError("task or model not recognized! Please refer the docs at : ") # TODO: add docs links
|
raise SyntaxError("task or model not recognized! Please refer the docs at : ") # TODO: add docs links
|
||||||
|
|
||||||
return task
|
return task
|
||||||
|
|
||||||
|
|
||||||
|
def profile(input, ops, n=10, device=None):
|
||||||
|
""" YOLOv5 speed/memory/FLOPs profiler
|
||||||
|
Usage:
|
||||||
|
input = torch.randn(16, 3, 640, 640)
|
||||||
|
m1 = lambda x: x * torch.sigmoid(x)
|
||||||
|
m2 = nn.SiLU()
|
||||||
|
profile(input, [m1, m2], n=100) # profile over 100 iterations
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
if not isinstance(device, torch.device):
|
||||||
|
device = select_device(device)
|
||||||
|
print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
|
||||||
|
f"{'input':>24s}{'output':>24s}")
|
||||||
|
|
||||||
|
for x in input if isinstance(input, list) else [input]:
|
||||||
|
x = x.to(device)
|
||||||
|
x.requires_grad = True
|
||||||
|
for m in ops if isinstance(ops, list) else [ops]:
|
||||||
|
m = m.to(device) if hasattr(m, 'to') else m # device
|
||||||
|
m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
|
||||||
|
tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
|
||||||
|
try:
|
||||||
|
flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPs
|
||||||
|
except Exception:
|
||||||
|
flops = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
for _ in range(n):
|
||||||
|
t[0] = time_sync()
|
||||||
|
y = m(x)
|
||||||
|
t[1] = time_sync()
|
||||||
|
try:
|
||||||
|
_ = (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
|
||||||
|
t[2] = time_sync()
|
||||||
|
except Exception: # no backward method
|
||||||
|
# print(e) # for debug
|
||||||
|
t[2] = float('nan')
|
||||||
|
tf += (t[1] - t[0]) * 1000 / n # ms per op forward
|
||||||
|
tb += (t[2] - t[1]) * 1000 / n # ms per op backward
|
||||||
|
mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB)
|
||||||
|
s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' for x in (x, y)) # shapes
|
||||||
|
p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters
|
||||||
|
print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
|
||||||
|
results.append([p, flops, mem, tf, tb, s_in, s_out])
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
results.append(None)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return results
|
||||||
|
@ -74,9 +74,16 @@ class DetectionTrainer(BaseTrainer):
|
|||||||
return self.compute_loss(preds, batch)
|
return self.compute_loss(preds, batch)
|
||||||
|
|
||||||
def label_loss_items(self, loss_items=None, prefix="train"):
|
def label_loss_items(self, loss_items=None, prefix="train"):
|
||||||
# We should just use named tensors here in future
|
"""
|
||||||
|
Returns a loss dict with labelled training loss items tensor
|
||||||
|
"""
|
||||||
|
# Not needed for classification but necessary for segmentation & detection
|
||||||
keys = [f"{prefix}/{x}" for x in self.loss_names]
|
keys = [f"{prefix}/{x}" for x in self.loss_names]
|
||||||
return dict(zip(keys, loss_items)) if loss_items is not None else keys
|
if loss_items is not None:
|
||||||
|
loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats
|
||||||
|
return dict(zip(keys, loss_items))
|
||||||
|
else:
|
||||||
|
return keys
|
||||||
|
|
||||||
def progress_string(self):
|
def progress_string(self):
|
||||||
return ('\n' + '%11s' *
|
return ('\n' + '%11s' *
|
||||||
|
Loading…
x
Reference in New Issue
Block a user