mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 05:24:22 +08:00
Model loading using YOLOv10PGT added, and pgt_coeff is now a cfg parameter
This commit is contained in:
parent
943a333fae
commit
9524af3bfe
@ -1,35 +1,43 @@
|
||||
from ultralytics import YOLOv10, YOLO
|
||||
from ultralytics import YOLOv10, YOLO, YOLOv10PGT
|
||||
# from ultralytics.engine.pgt_trainer import PGTTrainer
|
||||
import os
|
||||
from ultralytics.models.yolo.segment import PGTSegmentationTrainer
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
import torch
|
||||
|
||||
# nohup python run_pgt_train.py --device 0 > ./output_logs/gpu0_yolov10_pgt_train.log 2>&1 &
|
||||
|
||||
def main(args):
|
||||
# model = YOLOv10()
|
||||
|
||||
model = YOLOv10PGT('yolov10n.pt')
|
||||
model.train(
|
||||
data=args.data_yaml,
|
||||
epochs=args.epochs,
|
||||
batch=args.batch_size,
|
||||
# amp=False,
|
||||
# pgt_coeff=1.5,
|
||||
# cfg='pgt_train.yaml', # Load and train model with the config file
|
||||
)
|
||||
# If you want to finetune the model with pretrained weights, you could load the
|
||||
# pretrained weights like below
|
||||
# pretrained weights like below
|
||||
# model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
|
||||
# or
|
||||
# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
|
||||
# model = YOLOv10('yolov10n.pt', task='segment')
|
||||
# model = YOLOv10('yolov10n.pt', task='segment')
|
||||
|
||||
args_dict = dict(
|
||||
model='yolov10n.pt',
|
||||
data=args.data_yaml,
|
||||
epochs=args.epochs, batch=args.batch_size,
|
||||
# pgt_coeff=5.0,
|
||||
# cfg = 'pgt_train.yaml', # This can be edited for full control of the training process
|
||||
)
|
||||
trainer = PGTSegmentationTrainer(overrides=args_dict)
|
||||
trainer.train(
|
||||
# debug=True,
|
||||
# args = dict(pgt_coeff=0.1), # Should add later to config
|
||||
)
|
||||
# args_dict = dict(
|
||||
# model='yolov10n.pt',
|
||||
# data=args.data_yaml,
|
||||
# epochs=args.epochs, batch=args.batch_size,
|
||||
# # pgt_coeff=5.0,
|
||||
# # cfg = 'pgt_train.yaml', # This can be edited for full control of the training process
|
||||
# )
|
||||
# trainer = PGTSegmentationTrainer(overrides=args_dict)
|
||||
# trainer.train(
|
||||
# # debug=True,
|
||||
# # args = dict(pgt_coeff=0.1), # Should add later to config
|
||||
# )
|
||||
|
||||
# Create a directory to save model weights if it doesn't exist
|
||||
model_weights_dir = 'model_weights'
|
||||
@ -40,10 +48,12 @@ def main(args):
|
||||
current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
data_yaml_base = os.path.splitext(os.path.basename(args.data_yaml))[0]
|
||||
model_save_path = os.path.join(model_weights_dir, f'yolov10_{data_yaml_base}_trained_{current_time}.pt')
|
||||
trainer.model.save(model_save_path)
|
||||
model.save(model_save_path)
|
||||
# torch.save(trainer.model.state_dict(), model_save_path)
|
||||
|
||||
|
||||
# Evaluate the model on the validation set
|
||||
results = trainer.val(data=args.data_yaml)
|
||||
results = model.val(data=args.data_yaml)
|
||||
|
||||
# Print the evaluation results
|
||||
print(results)
|
||||
|
@ -3,7 +3,7 @@
|
||||
__version__ = "8.1.34"
|
||||
|
||||
from ultralytics.data.explorer.explorer import Explorer
|
||||
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld, YOLOv10
|
||||
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld, YOLOv10, YOLOv10PGT
|
||||
from ultralytics.models.fastsam import FastSAM
|
||||
from ultralytics.models.nas import NAS
|
||||
from ultralytics.utils import ASSETS, SETTINGS as settings
|
||||
@ -23,5 +23,6 @@ __all__ = (
|
||||
"download",
|
||||
"settings",
|
||||
"Explorer",
|
||||
"YOLOv10"
|
||||
"YOLOv10",
|
||||
"YOLOv10PGT",
|
||||
)
|
||||
|
@ -41,6 +41,7 @@ overlap_mask: True # (bool) masks should overlap during training (segment train
|
||||
mask_ratio: 4 # (int) mask downsample ratio (segment train only)
|
||||
# Classification
|
||||
dropout: 0.0 # (float) use dropout regularization (classify train only)
|
||||
pgt_coeff: 2.0 # (float) PGT loss coefficient
|
||||
|
||||
# Val/Test settings ----------------------------------------------------------------------------------------------------
|
||||
val: True # (bool) validate/test during training
|
||||
|
@ -41,6 +41,7 @@ overlap_mask: True # (bool) masks should overlap during training (segment train
|
||||
mask_ratio: 4 # (int) mask downsample ratio (segment train only)
|
||||
# Classification
|
||||
dropout: 0.0 # (float) use dropout regularization (classify train only)
|
||||
pgt_coeff: 1.0 # (float) PGT loss coefficient
|
||||
|
||||
# Val/Test settings ----------------------------------------------------------------------------------------------------
|
||||
val: True # (bool) validate/test during training
|
||||
|
@ -668,9 +668,10 @@ class Model(nn.Module):
|
||||
self.metrics = getattr(self.trainer.validator, "metrics", None) # TODO: no metrics returned by DDP
|
||||
return self.metrics
|
||||
|
||||
def train_pgt(
|
||||
def train_pgt( # Currently unused, but should be considered if changes to the train function are made
|
||||
self,
|
||||
trainer=None,
|
||||
debug=False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -733,7 +734,10 @@ class Model(nn.Module):
|
||||
pass
|
||||
|
||||
self.trainer.hub_session = self.session # attach optional HUB session
|
||||
self.trainer.train()
|
||||
if debug:
|
||||
self.trainer.train(debug=debug)
|
||||
else:
|
||||
self.trainer.train()
|
||||
# Update model and cfg after training
|
||||
if RANK in (-1, 0):
|
||||
ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
|
||||
|
@ -3,6 +3,6 @@
|
||||
from .rtdetr import RTDETR
|
||||
from .sam import SAM
|
||||
from .yolo import YOLO, YOLOWorld
|
||||
from .yolov10 import YOLOv10
|
||||
from .yolov10 import YOLOv10, YOLOv10PGT
|
||||
|
||||
__all__ = "YOLO", "RTDETR", "SAM", "YOLOWorld", "YOLOv10" # allow simpler import
|
||||
__all__ = "YOLO", "RTDETR", "SAM", "YOLOWorld", "YOLOv10", "YOLOv10PGT" # allow simpler import
|
||||
|
@ -5,9 +5,18 @@ from copy import copy
|
||||
from ultralytics.models import yolo
|
||||
from ultralytics.nn.tasks import SegmentationModel, DetectionModel
|
||||
from ultralytics.utils import DEFAULT_CFG, RANK
|
||||
# from ultralytics.utils import yaml_load, IterableSimpleNamespace, ROOT
|
||||
from ultralytics.utils.plotting import plot_images, plot_results
|
||||
from ultralytics.models.yolov10.model import YOLOv10DetectionModel, YOLOv10PGTDetectionModel
|
||||
from ultralytics.models.yolov10.val import YOLOv10DetectionValidator, YOLOv10PGTDetectionValidator
|
||||
from ultralytics.models.yolov10.model import YOLOv10PGTDetectionModel
|
||||
from ultralytics.models.yolov10.val import YOLOv10PGTDetectionValidator
|
||||
|
||||
# # Default configuration
|
||||
# DEFAULT_CFG_DICT = yaml_load(ROOT / "cfg/pgt_train.yaml")
|
||||
# for k, v in DEFAULT_CFG_DICT.items():
|
||||
# if isinstance(v, str) and v.lower() == "none":
|
||||
# DEFAULT_CFG_DICT[k] = None
|
||||
# DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
|
||||
# DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)
|
||||
|
||||
class PGTSegmentationTrainer(yolo.detect.PGTDetectionTrainer):
|
||||
"""
|
||||
|
@ -1,5 +1,5 @@
|
||||
from .model import YOLOv10
|
||||
from .model import YOLOv10, YOLOv10PGT
|
||||
from .predict import YOLOv10DetectionPredictor
|
||||
from .val import YOLOv10DetectionValidator
|
||||
|
||||
__all__ = "YOLOv10DetectionPredictor", "YOLOv10DetectionValidator", "YOLOv10"
|
||||
__all__ = "YOLOv10DetectionPredictor", "YOLOv10DetectionValidator", "YOLOv10", "YOLOv10PGT"
|
||||
|
@ -4,6 +4,7 @@ from .val import YOLOv10DetectionValidator
|
||||
from .predict import YOLOv10DetectionPredictor
|
||||
from .train import YOLOv10DetectionTrainer
|
||||
from .pgt_train import YOLOv10PGTDetectionTrainer
|
||||
# from ..yolo.segment import PGTSegmentationTrainer
|
||||
# from .pgt_trainer import YOLOv10DetectionTrainer
|
||||
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
@ -36,4 +37,36 @@ class YOLOv10(Model, PyTorchModelHubMixin, model_card_template=card_template_tex
|
||||
"validator": YOLOv10DetectionValidator,
|
||||
"predictor": YOLOv10DetectionPredictor,
|
||||
},
|
||||
}
|
||||
|
||||
def _get_pgt_segmentation_trainer():
    """Import and return the ``PGTSegmentationTrainer`` class on demand.

    The import runs inside the function body instead of at module import time
    (presumably to avoid a circular import with ``..yolo.segment`` -- confirm
    against that package's own imports).

    Returns:
        type: The ``PGTSegmentationTrainer`` class itself, not an instance.
    """
    from ..yolo.segment import PGTSegmentationTrainer as trainer_cls

    return trainer_cls
|
||||
|
||||
class YOLOv10PGT(Model, PyTorchModelHubMixin, model_card_template=card_template_text):
    """YOLOv10 model wrapper whose ``detect`` task is trained with the PGT trainer.

    Behaves like the other ``Model`` subclasses in this package; the only
    PGT-specific piece visible here is ``task_map``, which selects
    ``PGTSegmentationTrainer`` as the trainer for the ``detect`` task.
    """

    def __init__(self, model="yolov10n.pt", task=None, verbose=False,
                 names=None):
        """Initialize the model and optionally override its class names.

        Args:
            model: Forwarded to ``Model.__init__`` (defaults to ``"yolov10n.pt"``).
            task: Forwarded to ``Model.__init__``.
            verbose: Forwarded to ``Model.__init__``.
            names: If not ``None``, assigned to ``self.model.names`` after the
                base initializer has run.
        """
        super().__init__(model=model, task=task, verbose=verbose)
        if names is not None:
            self.model.names = names

    def push_to_hub(self, repo_name, **kwargs):
        """Push the model to the Hub, adding ``names``, ``model`` and ``task`` to the config.

        Args:
            repo_name: Forwarded as the first positional argument of the parent
                ``push_to_hub``.
            **kwargs: Forwarded to the parent ``push_to_hub``. An optional
                ``config`` mapping is extended with the keys ``names``,
                ``model`` and ``task``; a missing or ``None`` config is
                treated as empty.

        Returns:
            Whatever the parent ``push_to_hub`` returns.
        """
        # Work on a copy so the caller's dict is not mutated as a side effect.
        config = dict(kwargs.get('config') or {})
        config['names'] = self.names
        config['model'] = self.model.yaml['yaml_file']
        config['task'] = self.task
        kwargs['config'] = config
        return super().push_to_hub(repo_name, **kwargs)

    @property
    def task_map(self):
        """Map head to model, trainer, validator, and predictor classes."""
        return {
            "detect": {
                "model": YOLOv10DetectionModel,
                # Resolved on each access through the lazy-import helper.
                "trainer": _get_pgt_segmentation_trainer(),
                "validator": YOLOv10DetectionValidator,
                "predictor": YOLOv10DetectionPredictor,
            },
        }
|
@ -654,7 +654,7 @@ class YOLOv10DetectionModel(DetectionModel):
|
||||
|
||||
class YOLOv10PGTDetectionModel(DetectionModel):
|
||||
def init_criterion(self):
|
||||
return v10PGTDetectLoss(self)
|
||||
return v10PGTDetectLoss(self, pgt_coeff=self.args.pgt_coeff if hasattr(self.args, 'pgt_coeff') else None)
|
||||
|
||||
class Ensemble(nn.ModuleList):
|
||||
"""Ensemble of models."""
|
||||
|
@ -727,12 +727,14 @@ class v10DetectLoss:
|
||||
return loss_one2many[0] + loss_one2one[0], torch.cat((loss_one2many[1], loss_one2one[1]))
|
||||
|
||||
class v10PGTDetectLoss:
|
||||
def __init__(self, model, pgt_coeff=3.0):
|
||||
def __init__(self, model, pgt_coeff):
|
||||
self.one2many = v8DetectionLoss(model, tal_topk=10)
|
||||
self.one2one = v8DetectionLoss(model, tal_topk=1)
|
||||
self.pgt_coeff = pgt_coeff
|
||||
self.pgt_coeff = pgt_coeff if pgt_coeff is not None else 2.0
|
||||
|
||||
def __call__(self, preds, batch, return_plaus=True, inference=False):
|
||||
def __call__(self, preds, batch, return_plaus=True, pgt_coeff=None):
|
||||
if pgt_coeff is not None:
|
||||
self.pgt_coeff = pgt_coeff
|
||||
batch['img'] = batch['img'].requires_grad_(True)
|
||||
one2many = preds["one2many"]
|
||||
loss_one2many = self.one2many(one2many, batch)
|
||||
@ -743,11 +745,6 @@ class v10PGTDetectLoss:
|
||||
if return_plaus:
|
||||
smask = get_dist_reg(batch['img'], batch['masks'])#.requires_grad_(True)
|
||||
|
||||
# graph = False if inference else True
|
||||
# grad = torch.autograd.grad(loss, batch['img'],
|
||||
# retain_graph=True,
|
||||
# create_graph=graph,
|
||||
# )[0]
|
||||
try:
|
||||
grad = torch.autograd.grad(loss, batch['img'],
|
||||
retain_graph=True,
|
||||
@ -764,6 +761,7 @@ class v10PGTDetectLoss:
|
||||
|
||||
plaus_loss = plaus_loss_fn(grad, smask, self.pgt_coeff)
|
||||
# self.loss_items = torch.cat((self.loss_items, plaus_loss.unsqueeze(0)))
|
||||
|
||||
loss += plaus_loss
|
||||
|
||||
return loss, torch.cat((loss_one2many[1], loss_one2one[1], plaus_loss.unsqueeze(0)))
|
||||
|
@ -39,8 +39,11 @@ def get_dist_reg(images, seg_mask):
|
||||
kernel_size += 1
|
||||
seg_mask1 = T.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=sigma)(seg_mask)
|
||||
if torch.max(seg_mask1) > 1.0:
|
||||
seg_mask1 = (seg_mask1 - seg_mask1.min()) / (seg_mask1.max() - seg_mask1.min())
|
||||
# seg_mask1 = (seg_mask1 - seg_mask1.min()) / (seg_mask1.max() - seg_mask1.min())
|
||||
seg_mask1 = normalize_tensor(seg_mask1)
|
||||
smask = torch.max(smask, seg_mask1)
|
||||
|
||||
smask = normalize_tensor(smask)
|
||||
return smask
|
||||
|
||||
def get_gradient(img, grad_wrt, norm=False, absolute=True, grayscale=False, keepmean=False):
|
||||
@ -374,6 +377,29 @@ def normalize_batch(x):
|
||||
|
||||
return x_
|
||||
|
||||
def normalize_batch_nonan(x):
    """
    Min-max normalize each sample of a batch independently, without producing NaNs.

    Every batch element ``x[i]`` is rescaled with the minimum and maximum taken
    over all of its remaining dimensions (not per channel). Samples whose
    values are all identical have a zero range and would yield 0/0; they are
    mapped to zeros instead.

    Args:
        x (torch.Tensor): Input tensor whose first dimension is the batch
            dimension, e.g. (batch_size, channels, height, width).

    Returns:
        torch.Tensor: Tensor of the same shape as the input. Samples with a
            non-zero value range span [0, 1]; constant samples are all zeros.
    """
    x_ = torch.zeros_like(x)
    for i in range(x.shape[0]):
        sample = x[i]
        if torch.all(sample == 0):
            # All-zero sample: copied through unchanged.
            x_[i] = sample
            continue

        lo = sample.min()
        hi = sample.max()
        if hi == lo:
            # Constant non-zero sample: the range is zero, so dividing would
            # give NaN. Subtracting the minimum yields zeros instead.
            x_[i] = sample - lo
            continue

        x_[i] = (sample - lo) / (hi - lo)

    return x_
|
||||
|
||||
def get_detections(model_clone, img):
|
||||
"""
|
||||
Get detections from a model given an input image and targets.
|
||||
|
Loading…
x
Reference in New Issue
Block a user