mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
added attribution visualization with run_attribution.py
This commit is contained in:
parent
cd2f79c702
commit
b02bf58de6
1
.gitignore
vendored
1
.gitignore
vendored
@ -51,6 +51,7 @@ coverage.xml
|
|||||||
.hypothesis/
|
.hypothesis/
|
||||||
.pytest_cache/
|
.pytest_cache/
|
||||||
mlruns/
|
mlruns/
|
||||||
|
figures/
|
||||||
|
|
||||||
# Translations
|
# Translations
|
||||||
*.mo
|
*.mo
|
||||||
|
34
run_attribution.py
Normal file
34
run_attribution.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
from ultralytics import YOLOv10, YOLO
|
||||||
|
# from ultralytics.engine.pgt_trainer import PGTTrainer
|
||||||
|
# from ultralytics import BaseTrainer
|
||||||
|
# from ultralytics.engine.trainer import BaseTrainer
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Set CUDA device (only needed for multi-gpu machines)
|
||||||
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
|
||||||
|
|
||||||
|
# model = YOLOv10()
|
||||||
|
# model = YOLO()
|
||||||
|
# If you want to finetune the model with pretrained weights, you could load the
|
||||||
|
# pretrained weights like below
|
||||||
|
# model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
|
||||||
|
# or
|
||||||
|
# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
|
||||||
|
model = YOLOv10('yolov10n.pt')
|
||||||
|
|
||||||
|
model.train(data='coco.yaml',
|
||||||
|
trainer=model._smart_load("pgt_trainer"), # This is needed to generate attributions (will be used later to train via PGT)
|
||||||
|
# Add return_images as input parameter
|
||||||
|
epochs=500, batch=16, imgsz=640,
|
||||||
|
debug=True, # If debug = True, the attributions will be saved in the figures folder
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save the trained model
|
||||||
|
model.save('yolov10_coco_trained.pt')
|
||||||
|
|
||||||
|
# Evaluate the model on the validation set
|
||||||
|
results = model.val(data='coco.yaml')
|
||||||
|
|
||||||
|
# Print the evaluation results
|
||||||
|
print(results)
|
83
run_train.py
Normal file
83
run_train.py
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
from ultralytics import YOLOv10, YOLO
|
||||||
|
# from ultralytics.engine.pgt_trainer import PGTTrainer
|
||||||
|
# from ultralytics import BaseTrainer
|
||||||
|
# from ultralytics.engine.trainer import BaseTrainer
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Set CUDA device (only needed for multi-gpu machines)
|
||||||
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
|
||||||
|
|
||||||
|
model = YOLOv10()
|
||||||
|
# model = YOLO()
|
||||||
|
# If you want to finetune the model with pretrained weights, you could load the
|
||||||
|
# pretrained weights like below
|
||||||
|
# model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
|
||||||
|
# or
|
||||||
|
# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
|
||||||
|
# model = YOLOv10('yolov10m.pt')
|
||||||
|
|
||||||
|
model.train(data='coco.yaml',
|
||||||
|
# Add return_images as input parameter
|
||||||
|
epochs=500, batch=16, imgsz=640,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save the trained model
|
||||||
|
model.save('yolov10_coco_trained.pt')
|
||||||
|
|
||||||
|
# Evaluate the model on the validation set
|
||||||
|
results = model.val(data='coco.yaml')
|
||||||
|
|
||||||
|
# Print the evaluation results
|
||||||
|
print(results)
|
||||||
|
|
||||||
|
# import torch
|
||||||
|
# from torch.utils.data import DataLoader
|
||||||
|
# from torchvision import datasets, transforms
|
||||||
|
|
||||||
|
# # Define the transformation for the dataset
|
||||||
|
# transform = transforms.Compose([
|
||||||
|
# transforms.Resize((640, 640)),
|
||||||
|
# transforms.ToTensor()
|
||||||
|
# ])
|
||||||
|
|
||||||
|
# # Load the COCO dataset
|
||||||
|
# train_dataset = datasets.CocoDetection(root='data/nielseni6/coco/train2017', annFile='/data/nielseni6/coco/annotations/instances_train2017.json', transform=transform)
|
||||||
|
# val_dataset = datasets.CocoDetection(root='data/nielseni6/coco/val2017', annFile='/data/nielseni6/coco/annotations/instances_val2017.json', transform=transform)
|
||||||
|
|
||||||
|
# # Create data loaders
|
||||||
|
# train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4)
|
||||||
|
# val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=4)
|
||||||
|
|
||||||
|
# model = YOLOv10()
|
||||||
|
|
||||||
|
# # Define the optimizer
|
||||||
|
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
||||||
|
|
||||||
|
# # Training loop
|
||||||
|
# for epoch in range(500):
|
||||||
|
# model.train()
|
||||||
|
# for images, targets in train_loader:
|
||||||
|
# images = images.to('cuda')
|
||||||
|
# targets = [{k: v.to('cuda') for k, v in t.items()} for t in targets]
|
||||||
|
# loss = model(images, targets)
|
||||||
|
# loss.backward()
|
||||||
|
# optimizer.step()
|
||||||
|
# optimizer.zero_grad()
|
||||||
|
|
||||||
|
# # Validation loop
|
||||||
|
# model.eval()
|
||||||
|
# with torch.no_grad():
|
||||||
|
# for images, targets in val_loader:
|
||||||
|
# images = images.to('cuda')
|
||||||
|
# targets = [{k: v.to('cuda') for k, v in t.items()} for t in targets]
|
||||||
|
# results = model(images, targets)
|
||||||
|
|
||||||
|
# # Save the trained model
|
||||||
|
# model.save('yolov10_coco_trained.pt')
|
||||||
|
|
||||||
|
# # Evaluate the model on the validation set
|
||||||
|
# results = model.val(data='coco.yaml')
|
||||||
|
|
||||||
|
# # Print the evaluation results
|
||||||
|
# print(results)
|
96
run_val.py
Normal file
96
run_val.py
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
from ultralytics import YOLOv10
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
# Define the device
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
|
# model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
|
||||||
|
# model = YOLOv10.from_pretrained('jameslahm/yolov10n')
|
||||||
|
# or
|
||||||
|
# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
|
||||||
|
# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10n.pt
|
||||||
|
# model = YOLOv10('yolov10{n/s/m/b/l/x}.pt')
|
||||||
|
model = YOLOv10('yolov10n.pt').to(device)
|
||||||
|
|
||||||
|
# Load the image
|
||||||
|
# path = '/home/nielseni6/PythonScripts/Github/yolov10/images/fat-dog.jpg'
|
||||||
|
path = '/home/nielseni6/PythonScripts/Github/yolov10/images/The-Cardinal-Bird.jpg'
|
||||||
|
image = Image.open(path)
|
||||||
|
|
||||||
|
# Define the transformation to resize the image, convert it to a tensor, and normalize it
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.Resize((640, 640)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||||
|
])
|
||||||
|
|
||||||
|
# Apply the transformation
|
||||||
|
image_tensor = transform(image)
|
||||||
|
|
||||||
|
# Add a batch dimension
|
||||||
|
image_tensor = image_tensor.unsqueeze(0).to(device)
|
||||||
|
image_tensor = image_tensor.requires_grad_(True)
|
||||||
|
|
||||||
|
|
||||||
|
# Predict for a specific image
|
||||||
|
# results = model.predict(image_tensor, save=True)
|
||||||
|
# model.requires_grad_(True)
|
||||||
|
|
||||||
|
|
||||||
|
# for p in model.parameters():
|
||||||
|
# p.requires_grad = True
|
||||||
|
results = model.predict(image_tensor, save=True)
|
||||||
|
|
||||||
|
# Display the results
|
||||||
|
for result in results:
|
||||||
|
print(result)
|
||||||
|
|
||||||
|
# pred = results[0].boxes[0].conf
|
||||||
|
|
||||||
|
# # Hook to store the activations
|
||||||
|
# activations = {}
|
||||||
|
|
||||||
|
# def get_activation(name):
|
||||||
|
# def hook(model, input, output):
|
||||||
|
# activations[name] = output
|
||||||
|
# return hook
|
||||||
|
|
||||||
|
# # Register hooks for each layer you want to inspect
|
||||||
|
# for name, layer in model.model.named_modules():
|
||||||
|
# layer.register_forward_hook(get_activation(name))
|
||||||
|
|
||||||
|
# # Run the model to get activations
|
||||||
|
# results = model.predict(image_tensor, save=True, visualize=True)
|
||||||
|
|
||||||
|
# # # Print the activations
|
||||||
|
# # for name, activation in activations.items():
|
||||||
|
# # print(f"Activation from layer {name}: {activation}")
|
||||||
|
|
||||||
|
# # List activation names separately
|
||||||
|
# print("\nActivation layer names:")
|
||||||
|
# for name in activations.keys():
|
||||||
|
# print(name)
|
||||||
|
# # pred.backward()
|
||||||
|
|
||||||
|
# # Assuming 'model.23' is the layer of interest for bbox prediction and confidence
|
||||||
|
# activation = activations['model.23']['one2one'][0]
|
||||||
|
# act_23 = activations['model.23.cv3.2']
|
||||||
|
# act_dfl = activations['model.23.dfl.conv']
|
||||||
|
# act_conv = activations['model.0.conv']
|
||||||
|
# act_act = activations['model.0.act']
|
||||||
|
|
||||||
|
# # with torch.autograd.set_detect_anomaly(True):
|
||||||
|
# # pred.backward()
|
||||||
|
# grad = torch.autograd.grad(act_23, im, grad_outputs=torch.ones_like(act_23), create_graph=True, retain_graph=True)[0]
|
||||||
|
# # grad = torch.autograd.grad(pred, im, grad_outputs=torch.ones_like(pred), create_graph=True)[0]
|
||||||
|
# grad = torch.autograd.grad(activations['model.23']['one2one'][1][0],
|
||||||
|
# activations['model.23.one2one_cv3.2'],
|
||||||
|
# grad_outputs=torch.ones_like(activations['model.23']['one2one'][1][0]),
|
||||||
|
# create_graph=True, retain_graph=True)[0]
|
||||||
|
|
||||||
|
# # Print the results
|
||||||
|
# print(results)
|
||||||
|
|
||||||
|
# model.val(data='coco.yaml', batch=256)
|
@ -387,6 +387,7 @@ class Model(nn.Module):
|
|||||||
source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
|
source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
predictor=None,
|
predictor=None,
|
||||||
|
return_images: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> list:
|
) -> list:
|
||||||
"""
|
"""
|
||||||
@ -438,7 +439,7 @@ class Model(nn.Module):
|
|||||||
self.predictor.save_dir = get_save_dir(self.predictor.args)
|
self.predictor.save_dir = get_save_dir(self.predictor.args)
|
||||||
if prompts and hasattr(self.predictor, "set_prompts"): # for SAM-type models
|
if prompts and hasattr(self.predictor, "set_prompts"): # for SAM-type models
|
||||||
self.predictor.set_prompts(prompts)
|
self.predictor.set_prompts(prompts)
|
||||||
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
|
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream, return_images=return_images)
|
||||||
|
|
||||||
def track(
|
def track(
|
||||||
self,
|
self,
|
||||||
@ -590,6 +591,81 @@ class Model(nn.Module):
|
|||||||
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
|
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
|
self,
|
||||||
|
trainer=None,
|
||||||
|
debug=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Trains the model using the specified dataset and training configuration.
|
||||||
|
|
||||||
|
This method facilitates model training with a range of customizable settings and configurations. It supports
|
||||||
|
training with a custom trainer or the default training approach defined in the method. The method handles
|
||||||
|
different scenarios, such as resuming training from a checkpoint, integrating with Ultralytics HUB, and
|
||||||
|
updating model and configuration after training.
|
||||||
|
|
||||||
|
When using Ultralytics HUB, if the session already has a loaded model, the method prioritizes HUB training
|
||||||
|
arguments and issues a warning if local arguments are provided. It checks for pip updates and combines default
|
||||||
|
configurations, method-specific defaults, and user-provided arguments to configure the training process. After
|
||||||
|
training, it updates the model and its configurations, and optionally attaches metrics.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trainer (BaseTrainer, optional): An instance of a custom trainer class for training the model. If None, the
|
||||||
|
method uses a default trainer. Defaults to None.
|
||||||
|
**kwargs (any): Arbitrary keyword arguments representing the training configuration. These arguments are
|
||||||
|
used to customize various aspects of the training process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(dict | None): Training metrics if available and training is successful; otherwise, None.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If the model is not a PyTorch model.
|
||||||
|
PermissionError: If there is a permission issue with the HUB session.
|
||||||
|
ModuleNotFoundError: If the HUB SDK is not installed.
|
||||||
|
"""
|
||||||
|
self._check_is_pytorch_model()
|
||||||
|
if hasattr(self.session, "model") and self.session.model.id: # Ultralytics HUB session with loaded model
|
||||||
|
if any(kwargs):
|
||||||
|
LOGGER.warning("WARNING ⚠️ using HUB training arguments, ignoring local training arguments.")
|
||||||
|
kwargs = self.session.train_args # overwrite kwargs
|
||||||
|
|
||||||
|
checks.check_pip_update_available()
|
||||||
|
|
||||||
|
overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides
|
||||||
|
custom = {"data": DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task]} # method defaults
|
||||||
|
args = {**overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right
|
||||||
|
if args.get("resume"):
|
||||||
|
args["resume"] = self.ckpt_path
|
||||||
|
|
||||||
|
self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)
|
||||||
|
if not args.get("resume"): # manually set model only if not resuming
|
||||||
|
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
|
||||||
|
self.model = self.trainer.model
|
||||||
|
|
||||||
|
if SETTINGS["hub"] is True and not self.session:
|
||||||
|
# Create a model in HUB
|
||||||
|
try:
|
||||||
|
self.session = self._get_hub_session(self.model_name)
|
||||||
|
if self.session:
|
||||||
|
self.session.create_model(args)
|
||||||
|
# Check model was created
|
||||||
|
if not getattr(self.session.model, "id", None):
|
||||||
|
self.session = None
|
||||||
|
except (PermissionError, ModuleNotFoundError):
|
||||||
|
# Ignore PermissionError and ModuleNotFoundError which indicates hub-sdk not installed
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.trainer.hub_session = self.session # attach optional HUB session
|
||||||
|
self.trainer.train(debug=debug)
|
||||||
|
# Update model and cfg after training
|
||||||
|
if RANK in (-1, 0):
|
||||||
|
ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
|
||||||
|
self.model, _ = attempt_load_one_weight(ckpt)
|
||||||
|
self.overrides = self.model.args
|
||||||
|
self.metrics = getattr(self.trainer.validator, "metrics", None) # TODO: no metrics returned by DDP
|
||||||
|
return self.metrics
|
||||||
|
|
||||||
|
def train_pgt(
|
||||||
self,
|
self,
|
||||||
trainer=None,
|
trainer=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -662,7 +738,7 @@ class Model(nn.Module):
|
|||||||
self.overrides = self.model.args
|
self.overrides = self.model.args
|
||||||
self.metrics = getattr(self.trainer.validator, "metrics", None) # TODO: no metrics returned by DDP
|
self.metrics = getattr(self.trainer.validator, "metrics", None) # TODO: no metrics returned by DDP
|
||||||
return self.metrics
|
return self.metrics
|
||||||
|
|
||||||
def tune(
|
def tune(
|
||||||
self,
|
self,
|
||||||
use_ray=False,
|
use_ray=False,
|
||||||
|
785
ultralytics/engine/pgt_trainer.py
Normal file
785
ultralytics/engine/pgt_trainer.py
Normal file
@ -0,0 +1,785 @@
|
|||||||
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
"""
|
||||||
|
Train a model on a dataset.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
$ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
import warnings
|
||||||
|
from copy import deepcopy
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import distributed as dist
|
||||||
|
from torch import nn, optim
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import torchvision.transforms as T
|
||||||
|
|
||||||
|
from ultralytics.cfg import get_cfg, get_save_dir
|
||||||
|
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
||||||
|
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
|
||||||
|
from ultralytics.utils import (
|
||||||
|
DEFAULT_CFG,
|
||||||
|
LOGGER,
|
||||||
|
RANK,
|
||||||
|
TQDM,
|
||||||
|
__version__,
|
||||||
|
callbacks,
|
||||||
|
clean_url,
|
||||||
|
colorstr,
|
||||||
|
emojis,
|
||||||
|
yaml_save,
|
||||||
|
)
|
||||||
|
from ultralytics.utils.autobatch import check_train_batch_size
|
||||||
|
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
|
||||||
|
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
|
||||||
|
from ultralytics.utils.files import get_latest_run
|
||||||
|
from ultralytics.utils.torch_utils import (
|
||||||
|
EarlyStopping,
|
||||||
|
ModelEMA,
|
||||||
|
de_parallel,
|
||||||
|
init_seeds,
|
||||||
|
one_cycle,
|
||||||
|
select_device,
|
||||||
|
strip_optimizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PGTTrainer:
|
||||||
|
"""
|
||||||
|
BaseTrainer.
|
||||||
|
|
||||||
|
A base class for creating trainers.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
args (SimpleNamespace): Configuration for the trainer.
|
||||||
|
validator (BaseValidator): Validator instance.
|
||||||
|
model (nn.Module): Model instance.
|
||||||
|
callbacks (defaultdict): Dictionary of callbacks.
|
||||||
|
save_dir (Path): Directory to save results.
|
||||||
|
wdir (Path): Directory to save weights.
|
||||||
|
last (Path): Path to the last checkpoint.
|
||||||
|
best (Path): Path to the best checkpoint.
|
||||||
|
save_period (int): Save checkpoint every x epochs (disabled if < 1).
|
||||||
|
batch_size (int): Batch size for training.
|
||||||
|
epochs (int): Number of epochs to train for.
|
||||||
|
start_epoch (int): Starting epoch for training.
|
||||||
|
device (torch.device): Device to use for training.
|
||||||
|
amp (bool): Flag to enable AMP (Automatic Mixed Precision).
|
||||||
|
scaler (amp.GradScaler): Gradient scaler for AMP.
|
||||||
|
data (str): Path to data.
|
||||||
|
trainset (torch.utils.data.Dataset): Training dataset.
|
||||||
|
testset (torch.utils.data.Dataset): Testing dataset.
|
||||||
|
ema (nn.Module): EMA (Exponential Moving Average) of the model.
|
||||||
|
resume (bool): Resume training from a checkpoint.
|
||||||
|
lf (nn.Module): Loss function.
|
||||||
|
scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
|
||||||
|
best_fitness (float): The best fitness value achieved.
|
||||||
|
fitness (float): Current fitness value.
|
||||||
|
loss (float): Current loss value.
|
||||||
|
tloss (float): Total loss value.
|
||||||
|
loss_names (list): List of loss names.
|
||||||
|
csv (Path): Path to results CSV file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||||
|
"""
|
||||||
|
Initializes the BaseTrainer class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
|
||||||
|
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||||
|
"""
|
||||||
|
self.args = get_cfg(cfg, overrides)
|
||||||
|
self.check_resume(overrides)
|
||||||
|
self.device = select_device(self.args.device, self.args.batch)
|
||||||
|
self.validator = None
|
||||||
|
self.metrics = None
|
||||||
|
self.plots = {}
|
||||||
|
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
|
||||||
|
|
||||||
|
# Dirs
|
||||||
|
self.save_dir = get_save_dir(self.args)
|
||||||
|
self.args.name = self.save_dir.name # update name for loggers
|
||||||
|
self.wdir = self.save_dir / "weights" # weights dir
|
||||||
|
if RANK in (-1, 0):
|
||||||
|
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
|
||||||
|
self.args.save_dir = str(self.save_dir)
|
||||||
|
yaml_save(self.save_dir / "args.yaml", vars(self.args)) # save run args
|
||||||
|
self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt" # checkpoint paths
|
||||||
|
self.save_period = self.args.save_period
|
||||||
|
|
||||||
|
self.batch_size = self.args.batch
|
||||||
|
self.epochs = self.args.epochs
|
||||||
|
self.start_epoch = 0
|
||||||
|
if RANK == -1:
|
||||||
|
print_args(vars(self.args))
|
||||||
|
|
||||||
|
# Device
|
||||||
|
if self.device.type in ("cpu", "mps"):
|
||||||
|
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
|
||||||
|
|
||||||
|
# Model and Dataset
|
||||||
|
self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. yolov8n -> yolov8n.pt
|
||||||
|
try:
|
||||||
|
if self.args.task == "classify":
|
||||||
|
self.data = check_cls_dataset(self.args.data)
|
||||||
|
elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
|
||||||
|
"detect",
|
||||||
|
"segment",
|
||||||
|
"pose",
|
||||||
|
"obb",
|
||||||
|
):
|
||||||
|
self.data = check_det_dataset(self.args.data)
|
||||||
|
if "yaml_file" in self.data:
|
||||||
|
self.args.data = self.data["yaml_file"] # for validating 'yolo train data=url.zip' usage
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
|
||||||
|
|
||||||
|
self.trainset, self.testset = self.get_dataset(self.data)
|
||||||
|
self.ema = None
|
||||||
|
|
||||||
|
# Optimization utils init
|
||||||
|
self.lf = None
|
||||||
|
self.scheduler = None
|
||||||
|
|
||||||
|
# Epoch level metrics
|
||||||
|
self.best_fitness = None
|
||||||
|
self.fitness = None
|
||||||
|
self.loss = None
|
||||||
|
self.tloss = None
|
||||||
|
self.loss_names = ["Loss"]
|
||||||
|
self.csv = self.save_dir / "results.csv"
|
||||||
|
self.plot_idx = [0, 1, 2]
|
||||||
|
|
||||||
|
# Callbacks
|
||||||
|
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
||||||
|
if RANK in (-1, 0):
|
||||||
|
callbacks.add_integration_callbacks(self)
|
||||||
|
|
||||||
|
def add_callback(self, event: str, callback):
|
||||||
|
"""Appends the given callback."""
|
||||||
|
self.callbacks[event].append(callback)
|
||||||
|
|
||||||
|
def set_callback(self, event: str, callback):
|
||||||
|
"""Overrides the existing callbacks with the given callback."""
|
||||||
|
self.callbacks[event] = [callback]
|
||||||
|
|
||||||
|
def run_callbacks(self, event: str):
|
||||||
|
"""Run all existing callbacks associated with a particular event."""
|
||||||
|
for callback in self.callbacks.get(event, []):
|
||||||
|
callback(self)
|
||||||
|
|
||||||
|
def train(self, debug=False):
|
||||||
|
"""Allow device='', device=None on Multi-GPU systems to default to device=0."""
|
||||||
|
if isinstance(self.args.device, str) and len(self.args.device): # i.e. device='0' or device='0,1,2,3'
|
||||||
|
world_size = len(self.args.device.split(","))
|
||||||
|
elif isinstance(self.args.device, (tuple, list)): # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
|
||||||
|
world_size = len(self.args.device)
|
||||||
|
elif torch.cuda.is_available(): # i.e. device=None or device='' or device=number
|
||||||
|
world_size = 1 # default to device 0
|
||||||
|
else: # i.e. device='cpu' or 'mps'
|
||||||
|
world_size = 0
|
||||||
|
|
||||||
|
# Run subprocess if DDP training, else train normally
|
||||||
|
if world_size > 1 and "LOCAL_RANK" not in os.environ:
|
||||||
|
# Argument checks
|
||||||
|
if self.args.rect:
|
||||||
|
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
|
||||||
|
self.args.rect = False
|
||||||
|
if self.args.batch == -1:
|
||||||
|
LOGGER.warning(
|
||||||
|
"WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting "
|
||||||
|
"default 'batch=16'"
|
||||||
|
)
|
||||||
|
self.args.batch = 16
|
||||||
|
|
||||||
|
# Command
|
||||||
|
cmd, file = generate_ddp_command(world_size, self)
|
||||||
|
try:
|
||||||
|
LOGGER.info(f'{colorstr("DDP:")} debug command {" ".join(cmd)}')
|
||||||
|
subprocess.run(cmd, check=True)
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
ddp_cleanup(self, str(file))
|
||||||
|
|
||||||
|
else:
|
||||||
|
self._do_train(world_size, debug=debug)
|
||||||
|
|
||||||
|
def _setup_scheduler(self):
|
||||||
|
"""Initialize training learning rate scheduler."""
|
||||||
|
if self.args.cos_lr:
|
||||||
|
self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf']
|
||||||
|
else:
|
||||||
|
self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf # linear
|
||||||
|
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
|
||||||
|
|
||||||
|
def _setup_ddp(self, world_size):
|
||||||
|
"""Initializes and sets the DistributedDataParallel parameters for training."""
|
||||||
|
torch.cuda.set_device(RANK)
|
||||||
|
self.device = torch.device("cuda", RANK)
|
||||||
|
# LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
|
||||||
|
os.environ["NCCL_BLOCKING_WAIT"] = "1" # set to enforce timeout
|
||||||
|
dist.init_process_group(
|
||||||
|
backend="nccl" if dist.is_nccl_available() else "gloo",
|
||||||
|
timeout=timedelta(seconds=10800), # 3 hours
|
||||||
|
rank=RANK,
|
||||||
|
world_size=world_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _setup_train(self, world_size):
|
||||||
|
"""Builds dataloaders and optimizer on correct rank process."""
|
||||||
|
|
||||||
|
# Model
|
||||||
|
self.run_callbacks("on_pretrain_routine_start")
|
||||||
|
ckpt = self.setup_model()
|
||||||
|
self.model = self.model.to(self.device)
|
||||||
|
self.set_model_attributes()
|
||||||
|
|
||||||
|
# Freeze layers
|
||||||
|
freeze_list = (
|
||||||
|
self.args.freeze
|
||||||
|
if isinstance(self.args.freeze, list)
|
||||||
|
else range(self.args.freeze)
|
||||||
|
if isinstance(self.args.freeze, int)
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
always_freeze_names = [".dfl"] # always freeze these layers
|
||||||
|
freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
|
||||||
|
for k, v in self.model.named_parameters():
|
||||||
|
# v.register_hook(lambda x: torch.nan_to_num(x)) # NaN to 0 (commented for erratic training results)
|
||||||
|
if any(x in k for x in freeze_layer_names):
|
||||||
|
LOGGER.info(f"Freezing layer '{k}'")
|
||||||
|
v.requires_grad = False
|
||||||
|
elif not v.requires_grad and v.dtype.is_floating_point: # only floating point Tensor can require gradients
|
||||||
|
LOGGER.info(
|
||||||
|
f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. "
|
||||||
|
"See ultralytics.engine.trainer for customization of frozen layers."
|
||||||
|
)
|
||||||
|
v.requires_grad = True
|
||||||
|
|
||||||
|
# Check AMP
|
||||||
|
self.amp = torch.tensor(self.args.amp).to(self.device) # True or False
|
||||||
|
if self.amp and RANK in (-1, 0): # Single-GPU and DDP
|
||||||
|
callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them
|
||||||
|
self.amp = torch.tensor(check_amp(self.model), device=self.device)
|
||||||
|
callbacks.default_callbacks = callbacks_backup # restore callbacks
|
||||||
|
if RANK > -1 and world_size > 1: # DDP
|
||||||
|
dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None)
|
||||||
|
self.amp = bool(self.amp) # as boolean
|
||||||
|
self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
|
||||||
|
if world_size > 1:
|
||||||
|
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK])
|
||||||
|
|
||||||
|
# Check imgsz
|
||||||
|
gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32) # grid size (max stride)
|
||||||
|
self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
|
||||||
|
self.stride = gs # for multiscale training
|
||||||
|
|
||||||
|
# Batch size
|
||||||
|
if self.batch_size == -1 and RANK == -1: # single-GPU only, estimate best batch size
|
||||||
|
self.args.batch = self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)
|
||||||
|
|
||||||
|
# Dataloaders
|
||||||
|
batch_size = self.batch_size // max(world_size, 1)
|
||||||
|
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
|
||||||
|
if RANK in (-1, 0):
|
||||||
|
# Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
|
||||||
|
self.test_loader = self.get_dataloader(
|
||||||
|
self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
|
||||||
|
)
|
||||||
|
self.validator = self.get_validator()
|
||||||
|
metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
|
||||||
|
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
|
||||||
|
self.ema = ModelEMA(self.model)
|
||||||
|
if self.args.plots:
|
||||||
|
self.plot_training_labels()
|
||||||
|
|
||||||
|
# Optimizer
|
||||||
|
self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing
|
||||||
|
weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
|
||||||
|
iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
|
||||||
|
self.optimizer = self.build_optimizer(
|
||||||
|
model=self.model,
|
||||||
|
name=self.args.optimizer,
|
||||||
|
lr=self.args.lr0,
|
||||||
|
momentum=self.args.momentum,
|
||||||
|
decay=weight_decay,
|
||||||
|
iterations=iterations,
|
||||||
|
)
|
||||||
|
# Scheduler
|
||||||
|
self._setup_scheduler()
|
||||||
|
self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
|
||||||
|
self.resume_training(ckpt)
|
||||||
|
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
|
||||||
|
self.run_callbacks("on_pretrain_routine_end")
|
||||||
|
|
||||||
|
def _do_train(self, world_size=1, debug=False):
|
||||||
|
"""Train completed, evaluate and plot if specified by arguments."""
|
||||||
|
if world_size > 1:
|
||||||
|
self._setup_ddp(world_size)
|
||||||
|
self._setup_train(world_size)
|
||||||
|
|
||||||
|
nb = len(self.train_loader) # number of batches
|
||||||
|
nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1 # warmup iterations
|
||||||
|
last_opt_step = -1
|
||||||
|
self.epoch_time = None
|
||||||
|
self.epoch_time_start = time.time()
|
||||||
|
self.train_time_start = time.time()
|
||||||
|
self.run_callbacks("on_train_start")
|
||||||
|
LOGGER.info(
|
||||||
|
f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
|
||||||
|
f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
|
||||||
|
f"Logging results to {colorstr('bold', self.save_dir)}\n"
|
||||||
|
f'Starting training for ' + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
|
||||||
|
)
|
||||||
|
if self.args.close_mosaic:
|
||||||
|
base_idx = (self.epochs - self.args.close_mosaic) * nb
|
||||||
|
self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
|
||||||
|
epoch = self.start_epoch
|
||||||
|
while True:
|
||||||
|
self.epoch = epoch
|
||||||
|
self.run_callbacks("on_train_epoch_start")
|
||||||
|
self.model.train()
|
||||||
|
if RANK != -1:
|
||||||
|
self.train_loader.sampler.set_epoch(epoch)
|
||||||
|
pbar = enumerate(self.train_loader)
|
||||||
|
# Update dataloader attributes (optional)
|
||||||
|
if epoch == (self.epochs - self.args.close_mosaic):
|
||||||
|
self._close_dataloader_mosaic()
|
||||||
|
self.train_loader.reset()
|
||||||
|
|
||||||
|
if RANK in (-1, 0):
|
||||||
|
LOGGER.info(self.progress_string())
|
||||||
|
pbar = TQDM(enumerate(self.train_loader), total=nb)
|
||||||
|
self.tloss = None
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
for i, batch in pbar:
|
||||||
|
self.run_callbacks("on_train_batch_start")
|
||||||
|
# Warmup
|
||||||
|
ni = i + nb * epoch
|
||||||
|
if ni <= nw:
|
||||||
|
xi = [0, nw] # x interp
|
||||||
|
self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
|
||||||
|
for j, x in enumerate(self.optimizer.param_groups):
|
||||||
|
# Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
|
||||||
|
x["lr"] = np.interp(
|
||||||
|
ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)]
|
||||||
|
)
|
||||||
|
if "momentum" in x:
|
||||||
|
x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
|
||||||
|
|
||||||
|
# Forward
|
||||||
|
with torch.cuda.amp.autocast(self.amp):
|
||||||
|
batch = self.preprocess_batch(batch)
|
||||||
|
(self.loss, self.loss_items), images = self.model(batch, return_images=True)
|
||||||
|
|
||||||
|
if debug and (i % 250):
|
||||||
|
grad = torch.autograd.grad(self.loss, images, create_graph=True)[0]
|
||||||
|
# Convert tensors to numpy arrays
|
||||||
|
images_np = images.detach().cpu().numpy().transpose(0, 2, 3, 1)
|
||||||
|
grad_np = grad.detach().cpu().numpy().transpose(0, 2, 3, 1)
|
||||||
|
|
||||||
|
# Normalize grad for visualization
|
||||||
|
grad_np = (grad_np - grad_np.min()) / (grad_np.max() - grad_np.min())
|
||||||
|
|
||||||
|
for ix in range(images_np.shape[0]):
|
||||||
|
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
|
||||||
|
ax[0].imshow(images_np[i])
|
||||||
|
ax[0].set_title('Image')
|
||||||
|
ax[1].imshow(grad_np[i], cmap='jet')
|
||||||
|
ax[1].set_title('Gradient')
|
||||||
|
ax[2].imshow(images_np[i])
|
||||||
|
ax[2].imshow(grad_np[i], cmap='jet', alpha=0.5)
|
||||||
|
ax[2].set_title('Overlay')
|
||||||
|
|
||||||
|
save_dir_attr = "figures/attributions"
|
||||||
|
if not os.path.exists(save_dir_attr):
|
||||||
|
os.makedirs(save_dir_attr)
|
||||||
|
plt.savefig(f'{save_dir_attr}/debug_epoch_{epoch}_batch_{i}_image_{ix}.png')
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
if RANK != -1:
|
||||||
|
self.loss *= world_size
|
||||||
|
self.tloss = (
|
||||||
|
(self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
|
||||||
|
)
|
||||||
|
|
||||||
|
# Backward
|
||||||
|
self.scaler.scale(self.loss).backward()
|
||||||
|
|
||||||
|
# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
|
||||||
|
if ni - last_opt_step >= self.accumulate:
|
||||||
|
self.optimizer_step()
|
||||||
|
last_opt_step = ni
|
||||||
|
|
||||||
|
# Timed stopping
|
||||||
|
if self.args.time:
|
||||||
|
self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600)
|
||||||
|
if RANK != -1: # if DDP training
|
||||||
|
broadcast_list = [self.stop if RANK == 0 else None]
|
||||||
|
dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks
|
||||||
|
self.stop = broadcast_list[0]
|
||||||
|
if self.stop: # training time exceeded
|
||||||
|
break
|
||||||
|
|
||||||
|
# Log
|
||||||
|
mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G" # (GB)
|
||||||
|
loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1
|
||||||
|
losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
|
||||||
|
if RANK in (-1, 0):
|
||||||
|
pbar.set_description(
|
||||||
|
("%11s" * 2 + "%11.4g" * (2 + loss_len))
|
||||||
|
% (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
|
||||||
|
)
|
||||||
|
self.run_callbacks("on_batch_end")
|
||||||
|
if self.args.plots and ni in self.plot_idx:
|
||||||
|
self.plot_training_samples(batch, ni)
|
||||||
|
|
||||||
|
self.run_callbacks("on_train_batch_end")
|
||||||
|
|
||||||
|
self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
|
||||||
|
self.run_callbacks("on_train_epoch_end")
|
||||||
|
if RANK in (-1, 0):
|
||||||
|
final_epoch = epoch + 1 == self.epochs
|
||||||
|
self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
|
||||||
|
|
||||||
|
# Validation
|
||||||
|
if (self.args.val and (((epoch+1) % self.args.val_period == 0) or (self.epochs - epoch) <= 10)) \
|
||||||
|
or final_epoch or self.stopper.possible_stop or self.stop:
|
||||||
|
self.metrics, self.fitness = self.validate()
|
||||||
|
self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
|
||||||
|
self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch
|
||||||
|
if self.args.time:
|
||||||
|
self.stop |= (time.time() - self.train_time_start) > (self.args.time * 3600)
|
||||||
|
|
||||||
|
# Save model
|
||||||
|
if self.args.save or final_epoch:
|
||||||
|
self.save_model()
|
||||||
|
self.run_callbacks("on_model_save")
|
||||||
|
|
||||||
|
# Scheduler
|
||||||
|
t = time.time()
|
||||||
|
self.epoch_time = t - self.epoch_time_start
|
||||||
|
self.epoch_time_start = t
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore") # suppress 'Detected lr_scheduler.step() before optimizer.step()'
|
||||||
|
if self.args.time:
|
||||||
|
mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
|
||||||
|
self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
|
||||||
|
self._setup_scheduler()
|
||||||
|
self.scheduler.last_epoch = self.epoch # do not move
|
||||||
|
self.stop |= epoch >= self.epochs # stop if exceeded epochs
|
||||||
|
self.scheduler.step()
|
||||||
|
self.run_callbacks("on_fit_epoch_end")
|
||||||
|
torch.cuda.empty_cache() # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors
|
||||||
|
|
||||||
|
# Early Stopping
|
||||||
|
if RANK != -1: # if DDP training
|
||||||
|
broadcast_list = [self.stop if RANK == 0 else None]
|
||||||
|
dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks
|
||||||
|
self.stop = broadcast_list[0]
|
||||||
|
if self.stop:
|
||||||
|
break # must break all DDP ranks
|
||||||
|
epoch += 1
|
||||||
|
|
||||||
|
if RANK in (-1, 0):
|
||||||
|
# Do final val with best.pt
|
||||||
|
LOGGER.info(
|
||||||
|
f"\n{epoch - self.start_epoch + 1} epochs completed in "
|
||||||
|
f"{(time.time() - self.train_time_start) / 3600:.3f} hours."
|
||||||
|
)
|
||||||
|
self.final_eval()
|
||||||
|
if self.args.plots:
|
||||||
|
self.plot_metrics()
|
||||||
|
self.run_callbacks("on_train_end")
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
self.run_callbacks("teardown")
|
||||||
|
|
||||||
|
def save_model(self):
|
||||||
|
"""Save model training checkpoints with additional metadata."""
|
||||||
|
import pandas as pd # scope for faster startup
|
||||||
|
|
||||||
|
metrics = {**self.metrics, **{"fitness": self.fitness}}
|
||||||
|
results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
|
||||||
|
ckpt = {
|
||||||
|
"epoch": self.epoch,
|
||||||
|
"best_fitness": self.best_fitness,
|
||||||
|
"model": deepcopy(de_parallel(self.model)).half(),
|
||||||
|
"ema": deepcopy(self.ema.ema).half(),
|
||||||
|
"updates": self.ema.updates,
|
||||||
|
"optimizer": self.optimizer.state_dict(),
|
||||||
|
"train_args": vars(self.args), # save as dict
|
||||||
|
"train_metrics": metrics,
|
||||||
|
"train_results": results,
|
||||||
|
"date": datetime.now().isoformat(),
|
||||||
|
"version": __version__,
|
||||||
|
"license": "AGPL-3.0 (https://ultralytics.com/license)",
|
||||||
|
"docs": "https://docs.ultralytics.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Save last and best
|
||||||
|
torch.save(ckpt, self.last)
|
||||||
|
if self.best_fitness == self.fitness:
|
||||||
|
torch.save(ckpt, self.best)
|
||||||
|
if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
|
||||||
|
torch.save(ckpt, self.wdir / f"epoch{self.epoch}.pt")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_dataset(data):
|
||||||
|
"""
|
||||||
|
Get train, val path from data dict if it exists.
|
||||||
|
|
||||||
|
Returns None if data format is not recognized.
|
||||||
|
"""
|
||||||
|
return data["train"], data.get("val") or data.get("test")
|
||||||
|
|
||||||
|
def setup_model(self):
|
||||||
|
"""Load/create/download model for any task."""
|
||||||
|
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
||||||
|
return
|
||||||
|
|
||||||
|
model, weights = self.model, None
|
||||||
|
ckpt = None
|
||||||
|
if str(model).endswith(".pt"):
|
||||||
|
weights, ckpt = attempt_load_one_weight(model)
|
||||||
|
cfg = ckpt["model"].yaml
|
||||||
|
else:
|
||||||
|
cfg = model
|
||||||
|
self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) # calls Model(cfg, weights)
|
||||||
|
return ckpt
|
||||||
|
|
||||||
|
def optimizer_step(self):
|
||||||
|
"""Perform a single step of the training optimizer with gradient clipping and EMA update."""
|
||||||
|
self.scaler.unscale_(self.optimizer) # unscale gradients
|
||||||
|
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0) # clip gradients
|
||||||
|
self.scaler.step(self.optimizer)
|
||||||
|
self.scaler.update()
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
if self.ema:
|
||||||
|
self.ema.update(self.model)
|
||||||
|
|
||||||
|
def preprocess_batch(self, batch):
|
||||||
|
"""Allows custom preprocessing model inputs and ground truths depending on task type."""
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def validate(self):
|
||||||
|
"""
|
||||||
|
Runs validation on test set using self.validator.
|
||||||
|
|
||||||
|
The returned dict is expected to contain "fitness" key.
|
||||||
|
"""
|
||||||
|
metrics = self.validator(self)
|
||||||
|
fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
|
||||||
|
if not self.best_fitness or self.best_fitness < fitness:
|
||||||
|
self.best_fitness = fitness
|
||||||
|
return metrics, fitness
|
||||||
|
|
||||||
|
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||||
|
"""Get model and raise NotImplementedError for loading cfg files."""
|
||||||
|
raise NotImplementedError("This task trainer doesn't support loading cfg files")
|
||||||
|
|
||||||
|
def get_validator(self):
|
||||||
|
"""Returns a NotImplementedError when the get_validator function is called."""
|
||||||
|
raise NotImplementedError("get_validator function not implemented in trainer")
|
||||||
|
|
||||||
|
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
|
||||||
|
"""Returns dataloader derived from torch.data.Dataloader."""
|
||||||
|
raise NotImplementedError("get_dataloader function not implemented in trainer")
|
||||||
|
|
||||||
|
def build_dataset(self, img_path, mode="train", batch=None):
|
||||||
|
"""Build dataset."""
|
||||||
|
raise NotImplementedError("build_dataset function not implemented in trainer")
|
||||||
|
|
||||||
|
def label_loss_items(self, loss_items=None, prefix="train"):
|
||||||
|
"""
|
||||||
|
Returns a loss dict with labelled training loss items tensor.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This is not needed for classification but necessary for segmentation & detection
|
||||||
|
"""
|
||||||
|
return {"loss": loss_items} if loss_items is not None else ["loss"]
|
||||||
|
|
||||||
|
def set_model_attributes(self):
|
||||||
|
"""To set or update model parameters before training."""
|
||||||
|
self.model.names = self.data["names"]
|
||||||
|
|
||||||
|
def build_targets(self, preds, targets):
|
||||||
|
"""Builds target tensors for training YOLO model."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def progress_string(self):
|
||||||
|
"""Returns a string describing training progress."""
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# TODO: may need to put these following functions into callback
|
||||||
|
def plot_training_samples(self, batch, ni):
|
||||||
|
"""Plots training samples during YOLO training."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def plot_training_labels(self):
|
||||||
|
"""Plots training labels for YOLO model."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def save_metrics(self, metrics):
|
||||||
|
"""Saves training metrics to a CSV file."""
|
||||||
|
keys, vals = list(metrics.keys()), list(metrics.values())
|
||||||
|
n = len(metrics) + 1 # number of cols
|
||||||
|
s = "" if self.csv.exists() else (("%23s," * n % tuple(["epoch"] + keys)).rstrip(",") + "\n") # header
|
||||||
|
with open(self.csv, "a") as f:
|
||||||
|
f.write(s + ("%23.5g," * n % tuple([self.epoch + 1] + vals)).rstrip(",") + "\n")
|
||||||
|
|
||||||
|
def plot_metrics(self):
|
||||||
|
"""Plot and display metrics visually."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_plot(self, name, data=None):
|
||||||
|
"""Registers plots (e.g. to be consumed in callbacks)"""
|
||||||
|
path = Path(name)
|
||||||
|
self.plots[path] = {"data": data, "timestamp": time.time()}
|
||||||
|
|
||||||
|
def final_eval(self):
|
||||||
|
"""Performs final evaluation and validation for object detection YOLO model."""
|
||||||
|
for f in self.last, self.best:
|
||||||
|
if f.exists():
|
||||||
|
strip_optimizer(f) # strip optimizers
|
||||||
|
if f is self.best:
|
||||||
|
LOGGER.info(f"\nValidating {f}...")
|
||||||
|
self.validator.args.plots = self.args.plots
|
||||||
|
self.metrics = self.validator(model=f)
|
||||||
|
self.metrics.pop("fitness", None)
|
||||||
|
self.run_callbacks("on_fit_epoch_end")
|
||||||
|
|
||||||
|
def check_resume(self, overrides):
|
||||||
|
"""Check if resume checkpoint exists and update arguments accordingly."""
|
||||||
|
resume = self.args.resume
|
||||||
|
if resume:
|
||||||
|
try:
|
||||||
|
exists = isinstance(resume, (str, Path)) and Path(resume).exists()
|
||||||
|
last = Path(check_file(resume) if exists else get_latest_run())
|
||||||
|
|
||||||
|
# Check that resume data YAML exists, otherwise strip to force re-download of dataset
|
||||||
|
ckpt_args = attempt_load_weights(last).args
|
||||||
|
if not Path(ckpt_args["data"]).exists():
|
||||||
|
ckpt_args["data"] = self.args.data
|
||||||
|
|
||||||
|
resume = True
|
||||||
|
self.args = get_cfg(ckpt_args)
|
||||||
|
self.args.model = self.args.resume = str(last) # reinstate model
|
||||||
|
for k in "imgsz", "batch", "device": # allow arg updates to reduce memory or update device on resume
|
||||||
|
if k in overrides:
|
||||||
|
setattr(self.args, k, overrides[k])
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
"Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
|
||||||
|
"i.e. 'yolo train resume model=path/to/last.pt'"
|
||||||
|
) from e
|
||||||
|
self.resume = resume
|
||||||
|
|
||||||
|
def resume_training(self, ckpt):
|
||||||
|
"""Resume YOLO training from given epoch and best fitness."""
|
||||||
|
if ckpt is None or not self.resume:
|
||||||
|
return
|
||||||
|
best_fitness = 0.0
|
||||||
|
start_epoch = ckpt["epoch"] + 1
|
||||||
|
if ckpt["optimizer"] is not None:
|
||||||
|
self.optimizer.load_state_dict(ckpt["optimizer"]) # optimizer
|
||||||
|
best_fitness = ckpt["best_fitness"]
|
||||||
|
if self.ema and ckpt.get("ema"):
|
||||||
|
self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict()) # EMA
|
||||||
|
self.ema.updates = ckpt["updates"]
|
||||||
|
assert start_epoch > 0, (
|
||||||
|
f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
|
||||||
|
f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
|
||||||
|
)
|
||||||
|
LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs")
|
||||||
|
if self.epochs < start_epoch:
|
||||||
|
LOGGER.info(
|
||||||
|
f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
|
||||||
|
)
|
||||||
|
self.epochs += ckpt["epoch"] # finetune additional epochs
|
||||||
|
self.best_fitness = best_fitness
|
||||||
|
self.start_epoch = start_epoch
|
||||||
|
if start_epoch > (self.epochs - self.args.close_mosaic):
|
||||||
|
self._close_dataloader_mosaic()
|
||||||
|
|
||||||
|
def _close_dataloader_mosaic(self):
|
||||||
|
"""Update dataloaders to stop using mosaic augmentation."""
|
||||||
|
if hasattr(self.train_loader.dataset, "mosaic"):
|
||||||
|
self.train_loader.dataset.mosaic = False
|
||||||
|
if hasattr(self.train_loader.dataset, "close_mosaic"):
|
||||||
|
LOGGER.info("Closing dataloader mosaic")
|
||||||
|
self.train_loader.dataset.close_mosaic(hyp=self.args)
|
||||||
|
|
||||||
|
def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
|
||||||
|
"""
|
||||||
|
Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
|
||||||
|
weight decay, and number of iterations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (torch.nn.Module): The model for which to build an optimizer.
|
||||||
|
name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
|
||||||
|
based on the number of iterations. Default: 'auto'.
|
||||||
|
lr (float, optional): The learning rate for the optimizer. Default: 0.001.
|
||||||
|
momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
|
||||||
|
decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
|
||||||
|
iterations (float, optional): The number of iterations, which determines the optimizer if
|
||||||
|
name is 'auto'. Default: 1e5.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(torch.optim.Optimizer): The constructed optimizer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
g = [], [], [] # optimizer parameter groups
|
||||||
|
bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d()
|
||||||
|
if name == "auto":
|
||||||
|
LOGGER.info(
|
||||||
|
f"{colorstr('optimizer:')} 'optimizer=auto' found, "
|
||||||
|
f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
|
||||||
|
f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
|
||||||
|
)
|
||||||
|
nc = getattr(model, "nc", 10) # number of classes
|
||||||
|
lr_fit = round(0.002 * 5 / (4 + nc), 6) # lr0 fit equation to 6 decimal places
|
||||||
|
name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
|
||||||
|
self.args.warmup_bias_lr = 0.0 # no higher than 0.01 for Adam
|
||||||
|
|
||||||
|
for module_name, module in model.named_modules():
|
||||||
|
for param_name, param in module.named_parameters(recurse=False):
|
||||||
|
fullname = f"{module_name}.{param_name}" if module_name else param_name
|
||||||
|
if "bias" in fullname: # bias (no decay)
|
||||||
|
g[2].append(param)
|
||||||
|
elif isinstance(module, bn): # weight (no decay)
|
||||||
|
g[1].append(param)
|
||||||
|
else: # weight (with decay)
|
||||||
|
g[0].append(param)
|
||||||
|
|
||||||
|
if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"):
|
||||||
|
optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
|
||||||
|
elif name == "RMSProp":
|
||||||
|
optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
|
||||||
|
elif name == "SGD":
|
||||||
|
optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Optimizer '{name}' not found in list of available optimizers "
|
||||||
|
f"[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto]."
|
||||||
|
"To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics."
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer.add_param_group({"params": g[0], "weight_decay": decay}) # add g0 with weight_decay
|
||||||
|
optimizer.add_param_group({"params": g[1], "weight_decay": 0.0}) # add g1 (BatchNorm2d weights)
|
||||||
|
LOGGER.info(
|
||||||
|
f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
|
||||||
|
f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
|
||||||
|
)
|
||||||
|
return optimizer
|
@ -206,7 +206,7 @@ class BasePredictor:
|
|||||||
self.vid_writer = {}
|
self.vid_writer = {}
|
||||||
|
|
||||||
@smart_inference_mode()
|
@smart_inference_mode()
|
||||||
def stream_inference(self, source=None, model=None, *args, **kwargs):
|
def stream_inference(self, source=None, model=None, return_images = False, *args, **kwargs):
|
||||||
"""Streams real-time inference on camera feed and saves results to file."""
|
"""Streams real-time inference on camera feed and saves results to file."""
|
||||||
if self.args.verbose:
|
if self.args.verbose:
|
||||||
LOGGER.info("")
|
LOGGER.info("")
|
||||||
@ -243,6 +243,9 @@ class BasePredictor:
|
|||||||
with profilers[0]:
|
with profilers[0]:
|
||||||
im = self.preprocess(im0s)
|
im = self.preprocess(im0s)
|
||||||
|
|
||||||
|
if return_images:
|
||||||
|
im = im.requires_grad_(True)
|
||||||
|
|
||||||
# Inference
|
# Inference
|
||||||
with profilers[1]:
|
with profilers[1]:
|
||||||
preds = self.inference(im, *args, **kwargs)
|
preds = self.inference(im, *args, **kwargs)
|
||||||
@ -272,7 +275,7 @@ class BasePredictor:
|
|||||||
LOGGER.info("\n".join(s))
|
LOGGER.info("\n".join(s))
|
||||||
|
|
||||||
self.run_callbacks("on_predict_batch_end")
|
self.run_callbacks("on_predict_batch_end")
|
||||||
yield from self.results
|
yield from (self.results, im)
|
||||||
|
|
||||||
# Release assets
|
# Release assets
|
||||||
for v in self.vid_writer.values():
|
for v in self.vid_writer.values():
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
from .predict import DetectionPredictor
|
from .predict import DetectionPredictor
|
||||||
|
from .pgt_train import PGTDetectionTrainer
|
||||||
from .train import DetectionTrainer
|
from .train import DetectionTrainer
|
||||||
from .val import DetectionValidator
|
from .val import DetectionValidator
|
||||||
|
|
||||||
__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator"
|
__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator", "PGTDetectionTrainer"
|
||||||
|
144
ultralytics/models/yolo/detect/pgt_train.py
Normal file
144
ultralytics/models/yolo/detect/pgt_train.py
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
from copy import copy
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from ultralytics.data import build_dataloader, build_yolo_dataset
|
||||||
|
from ultralytics.engine.trainer import BaseTrainer
|
||||||
|
from ultralytics.engine.pgt_trainer import PGTTrainer
|
||||||
|
from ultralytics.models import yolo
|
||||||
|
from ultralytics.nn.tasks import DetectionModel
|
||||||
|
from ultralytics.utils import LOGGER, RANK
|
||||||
|
from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
|
||||||
|
from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first
|
||||||
|
|
||||||
|
|
||||||
|
class PGTDetectionTrainer(PGTTrainer):
|
||||||
|
"""
|
||||||
|
A class extending the BaseTrainer class for training based on a detection model.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from ultralytics.models.yolo.detect import DetectionTrainer
|
||||||
|
|
||||||
|
args = dict(model='yolov8n.pt', data='coco8.yaml', epochs=3)
|
||||||
|
trainer = DetectionTrainer(overrides=args)
|
||||||
|
trainer.train()
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def build_dataset(self, img_path, mode="train", batch=None):
|
||||||
|
"""
|
||||||
|
Build YOLO Dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img_path (str): Path to the folder containing images.
|
||||||
|
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
|
||||||
|
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
|
||||||
|
"""
|
||||||
|
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
|
||||||
|
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
|
||||||
|
|
||||||
|
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
|
||||||
|
"""Construct and return dataloader."""
|
||||||
|
assert mode in ["train", "val"]
|
||||||
|
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
||||||
|
dataset = self.build_dataset(dataset_path, mode, batch_size)
|
||||||
|
shuffle = mode == "train"
|
||||||
|
if getattr(dataset, "rect", False) and shuffle:
|
||||||
|
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
|
||||||
|
shuffle = False
|
||||||
|
workers = self.args.workers if mode == "train" else self.args.workers * 2
|
||||||
|
return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader
|
||||||
|
|
||||||
|
def preprocess_batch(self, batch):
|
||||||
|
"""Preprocesses a batch of images by scaling and converting to float."""
|
||||||
|
batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
|
||||||
|
if self.args.multi_scale:
|
||||||
|
imgs = batch["img"]
|
||||||
|
sz = (
|
||||||
|
random.randrange(self.args.imgsz * 0.5, self.args.imgsz * 1.5 + self.stride)
|
||||||
|
// self.stride
|
||||||
|
* self.stride
|
||||||
|
) # size
|
||||||
|
sf = sz / max(imgs.shape[2:]) # scale factor
|
||||||
|
if sf != 1:
|
||||||
|
ns = [
|
||||||
|
math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]
|
||||||
|
] # new shape (stretched to gs-multiple)
|
||||||
|
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
|
||||||
|
batch["img"] = imgs
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def set_model_attributes(self):
|
||||||
|
"""Nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)."""
|
||||||
|
# self.args.box *= 3 / nl # scale to layers
|
||||||
|
# self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
|
||||||
|
# self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
|
||||||
|
self.model.nc = self.data["nc"] # attach number of classes to model
|
||||||
|
self.model.names = self.data["names"] # attach class names to model
|
||||||
|
self.model.args = self.args # attach hyperparameters to model
|
||||||
|
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
|
||||||
|
|
||||||
|
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||||
|
"""Return a YOLO detection model."""
|
||||||
|
model = DetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
|
||||||
|
if weights:
|
||||||
|
model.load(weights)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def get_validator(self):
|
||||||
|
"""Returns a DetectionValidator for YOLO model validation."""
|
||||||
|
self.loss_names = "box_loss", "cls_loss", "dfl_loss"
|
||||||
|
return yolo.detect.DetectionValidator(
|
||||||
|
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
|
||||||
|
)
|
||||||
|
|
||||||
|
def label_loss_items(self, loss_items=None, prefix="train"):
|
||||||
|
"""
|
||||||
|
Returns a loss dict with labelled training loss items tensor.
|
||||||
|
|
||||||
|
Not needed for classification but necessary for segmentation & detection
|
||||||
|
"""
|
||||||
|
keys = [f"{prefix}/{x}" for x in self.loss_names]
|
||||||
|
if loss_items is not None:
|
||||||
|
loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats
|
||||||
|
return dict(zip(keys, loss_items))
|
||||||
|
else:
|
||||||
|
return keys
|
||||||
|
|
||||||
|
def progress_string(self):
|
||||||
|
"""Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size."""
|
||||||
|
return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
|
||||||
|
"Epoch",
|
||||||
|
"GPU_mem",
|
||||||
|
*self.loss_names,
|
||||||
|
"Instances",
|
||||||
|
"Size",
|
||||||
|
)
|
||||||
|
|
||||||
|
def plot_training_samples(self, batch, ni):
|
||||||
|
"""Plots training samples with their annotations."""
|
||||||
|
plot_images(
|
||||||
|
images=batch["img"],
|
||||||
|
batch_idx=batch["batch_idx"],
|
||||||
|
cls=batch["cls"].squeeze(-1),
|
||||||
|
bboxes=batch["bboxes"],
|
||||||
|
paths=batch["im_file"],
|
||||||
|
fname=self.save_dir / f"train_batch{ni}.jpg",
|
||||||
|
on_plot=self.on_plot,
|
||||||
|
)
|
||||||
|
|
||||||
|
def plot_metrics(self):
|
||||||
|
"""Plots metrics from a CSV file."""
|
||||||
|
plot_results(file=self.csv, on_plot=self.on_plot) # save results.png
|
||||||
|
|
||||||
|
def plot_training_labels(self):
|
||||||
|
"""Create a labeled training plot of the YOLO model."""
|
||||||
|
boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0)
|
||||||
|
cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0)
|
||||||
|
plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
|
@ -3,6 +3,8 @@ from ultralytics.nn.tasks import YOLOv10DetectionModel
|
|||||||
from .val import YOLOv10DetectionValidator
|
from .val import YOLOv10DetectionValidator
|
||||||
from .predict import YOLOv10DetectionPredictor
|
from .predict import YOLOv10DetectionPredictor
|
||||||
from .train import YOLOv10DetectionTrainer
|
from .train import YOLOv10DetectionTrainer
|
||||||
|
from .pgt_train import YOLOv10PGTDetectionTrainer
|
||||||
|
# from .pgt_trainer import YOLOv10DetectionTrainer
|
||||||
|
|
||||||
from huggingface_hub import PyTorchModelHubMixin
|
from huggingface_hub import PyTorchModelHubMixin
|
||||||
from .card import card_template_text
|
from .card import card_template_text
|
||||||
@ -30,6 +32,7 @@ class YOLOv10(Model, PyTorchModelHubMixin, model_card_template=card_template_tex
|
|||||||
"detect": {
|
"detect": {
|
||||||
"model": YOLOv10DetectionModel,
|
"model": YOLOv10DetectionModel,
|
||||||
"trainer": YOLOv10DetectionTrainer,
|
"trainer": YOLOv10DetectionTrainer,
|
||||||
|
"pgt_trainer": YOLOv10PGTDetectionTrainer,
|
||||||
"validator": YOLOv10DetectionValidator,
|
"validator": YOLOv10DetectionValidator,
|
||||||
"predictor": YOLOv10DetectionPredictor,
|
"predictor": YOLOv10DetectionPredictor,
|
||||||
},
|
},
|
||||||
|
21
ultralytics/models/yolov10/pgt_train.py
Normal file
21
ultralytics/models/yolov10/pgt_train.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from ultralytics.models.yolo.detect import DetectionTrainer
|
||||||
|
from ultralytics.models.yolo.detect import PGTDetectionTrainer
|
||||||
|
from .val import YOLOv10DetectionValidator
|
||||||
|
from .model import YOLOv10DetectionModel
|
||||||
|
from copy import copy
|
||||||
|
from ultralytics.utils import RANK
|
||||||
|
|
||||||
|
class YOLOv10PGTDetectionTrainer(PGTDetectionTrainer):
|
||||||
|
def get_validator(self):
|
||||||
|
"""Returns a DetectionValidator for YOLO model validation."""
|
||||||
|
self.loss_names = "box_om", "cls_om", "dfl_om", "box_oo", "cls_oo", "dfl_oo",
|
||||||
|
return YOLOv10DetectionValidator(
|
||||||
|
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||||
|
"""Return a YOLO detection model."""
|
||||||
|
model = YOLOv10DetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
|
||||||
|
if weights:
|
||||||
|
model.load(weights)
|
||||||
|
return model
|
@ -93,7 +93,7 @@ class BaseModel(nn.Module):
|
|||||||
return self.loss(x, *args, **kwargs)
|
return self.loss(x, *args, **kwargs)
|
||||||
return self.predict(x, *args, **kwargs)
|
return self.predict(x, *args, **kwargs)
|
||||||
|
|
||||||
def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
|
def predict(self, x, profile=False, visualize=False, augment=False, embed=None, return_images=False):
|
||||||
"""
|
"""
|
||||||
Perform a forward pass through the network.
|
Perform a forward pass through the network.
|
||||||
|
|
||||||
@ -107,9 +107,12 @@ class BaseModel(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
(torch.Tensor): The last output of the model.
|
(torch.Tensor): The last output of the model.
|
||||||
"""
|
"""
|
||||||
|
if return_images:
|
||||||
|
x = x.requires_grad_(True)
|
||||||
if augment:
|
if augment:
|
||||||
return self._predict_augment(x)
|
return self._predict_augment(x)
|
||||||
return self._predict_once(x, profile, visualize, embed)
|
out = self._predict_once(x, profile, visualize, embed)
|
||||||
|
return (out, x) if return_images else out
|
||||||
|
|
||||||
def _predict_once(self, x, profile=False, visualize=False, embed=None):
|
def _predict_once(self, x, profile=False, visualize=False, embed=None):
|
||||||
"""
|
"""
|
||||||
@ -140,13 +143,13 @@ class BaseModel(nn.Module):
|
|||||||
return torch.unbind(torch.cat(embeddings, 1), dim=0)
|
return torch.unbind(torch.cat(embeddings, 1), dim=0)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def _predict_augment(self, x):
|
def _predict_augment(self, x, *args, **kwargs):
|
||||||
"""Perform augmentations on input image x and return augmented inference."""
|
"""Perform augmentations on input image x and return augmented inference."""
|
||||||
LOGGER.warning(
|
LOGGER.warning(
|
||||||
f"WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. "
|
f"WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. "
|
||||||
f"Reverting to single-scale inference instead."
|
f"Reverting to single-scale inference instead."
|
||||||
)
|
)
|
||||||
return self._predict_once(x)
|
return self._predict_once(x, *args, **kwargs)
|
||||||
|
|
||||||
def _profile_one_layer(self, m, x, dt):
|
def _profile_one_layer(self, m, x, dt):
|
||||||
"""
|
"""
|
||||||
@ -260,7 +263,7 @@ class BaseModel(nn.Module):
|
|||||||
if verbose:
|
if verbose:
|
||||||
LOGGER.info(f"Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights")
|
LOGGER.info(f"Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights")
|
||||||
|
|
||||||
def loss(self, batch, preds=None):
|
def loss(self, batch, preds=None, return_images=False):
|
||||||
"""
|
"""
|
||||||
Compute loss.
|
Compute loss.
|
||||||
|
|
||||||
@ -271,8 +274,12 @@ class BaseModel(nn.Module):
|
|||||||
if not hasattr(self, "criterion"):
|
if not hasattr(self, "criterion"):
|
||||||
self.criterion = self.init_criterion()
|
self.criterion = self.init_criterion()
|
||||||
|
|
||||||
preds = self.forward(batch["img"]) if preds is None else preds
|
preds = self.forward(batch["img"], return_images=return_images) if preds is None else preds
|
||||||
return self.criterion(preds, batch)
|
if return_images:
|
||||||
|
preds, im = preds
|
||||||
|
loss = self.criterion(preds, batch)
|
||||||
|
out = loss if not return_images else (loss, im)
|
||||||
|
return out
|
||||||
|
|
||||||
def init_criterion(self):
|
def init_criterion(self):
|
||||||
"""Initialize the loss criterion for the BaseModel."""
|
"""Initialize the loss criterion for the BaseModel."""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user