mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Compare commits
15 Commits
53e5ba8924
...
37f38bc9a2
Author | SHA1 | Date | |
---|---|---|---|
![]() |
37f38bc9a2 | ||
![]() |
453c6e38a5 | ||
![]() |
1ef6cbcec5 | ||
![]() |
21e660fde9 | ||
![]() |
44c912647e | ||
![]() |
9524af3bfe | ||
![]() |
943a333fae | ||
![]() |
0efbfe7f4d | ||
![]() |
2ccec65edc | ||
![]() |
411157c18a | ||
![]() |
5953d3c9c6 | ||
![]() |
3a449d5a6c | ||
![]() |
38fa59edf2 | ||
![]() |
2a95a652bd | ||
![]() |
b02bf58de6 |
1
.gitignore
vendored
1
.gitignore
vendored
@ -51,6 +51,7 @@ coverage.xml
|
|||||||
.hypothesis/
|
.hypothesis/
|
||||||
.pytest_cache/
|
.pytest_cache/
|
||||||
mlruns/
|
mlruns/
|
||||||
|
figures/
|
||||||
|
|
||||||
# Translations
|
# Translations
|
||||||
*.mo
|
*.mo
|
||||||
|
@ -10,7 +10,7 @@ Please check out our new release on [**YOLOE**](https://github.com/THU-MIG/yoloe
|
|||||||
Comparison of performance, training cost, and inference efficiency between YOLOE (Ours) and YOLO-Worldv2 in terms of open text prompts.
|
Comparison of performance, training cost, and inference efficiency between YOLOE (Ours) and YOLO-Worldv2 in terms of open text prompts.
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
**YOLOE(ye)** is a highly **efficient**, **unified**, and **open** object detection and segmentation model for real-time seeing anything, like the human eye, under different prompt mechanisms, like *texts*, *visual inputs*, and the *prompt-free paradigm*.
|
**YOLOE(ye)** is a highly **efficient**, **unified**, and **open** object detection and segmentation model for real-time seeing anything, like the human eye, under different prompt mechanisms, like *texts*, *visual inputs*, and the *prompt-free paradigm*, with **zero inference and transferring overhead** compared with closed-set YOLOs.
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img src="https://github.com/THU-MIG/yoloe/blob/main/figures/visualization.svg" width=96%> <br>
|
<img src="https://github.com/THU-MIG/yoloe/blob/main/figures/visualization.svg" width=96%> <br>
|
||||||
|
49
run_attribution.py
Normal file
49
run_attribution.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
from ultralytics import YOLOv10, YOLO
|
||||||
|
# from ultralytics.engine.pgt_trainer import PGTTrainer
|
||||||
|
# from ultralytics import BaseTrainer
|
||||||
|
# from ultralytics.engine.trainer import BaseTrainer
|
||||||
|
import os
|
||||||
|
from ultralytics.models.yolo.segment import PGTSegmentationTrainer
|
||||||
|
|
||||||
|
|
||||||
|
# Set CUDA device (only needed for multi-gpu machines)
|
||||||
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
|
||||||
|
|
||||||
|
# model = YOLOv10()
|
||||||
|
# model = YOLO('yolov8n-seg.yaml').load('yolov8n.pt') # build from YAML and transfer weights
|
||||||
|
|
||||||
|
# model = YOLO()
|
||||||
|
# If you want to finetune the model with pretrained weights, you could load the
|
||||||
|
# pretrained weights like below
|
||||||
|
# model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
|
||||||
|
# or
|
||||||
|
# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
|
||||||
|
model = YOLOv10('yolov10n.pt', task='segment')
|
||||||
|
|
||||||
|
args = dict(model='yolov10n.pt', data='coco128-seg.yaml')
|
||||||
|
trainer = PGTSegmentationTrainer(overrides=args)
|
||||||
|
trainer.train(
|
||||||
|
# debug=True,
|
||||||
|
# args = dict(pgt_coeff=0.1),
|
||||||
|
)
|
||||||
|
|
||||||
|
# model.train(
|
||||||
|
# # data='coco.yaml',
|
||||||
|
# data='coco128-seg.yaml',
|
||||||
|
# trainer=model._smart_load("pgt_trainer"), # This is needed to generate attributions (will be used later to train via PGT)
|
||||||
|
# # Add return_images as input parameter
|
||||||
|
# epochs=500, batch=16, imgsz=640,
|
||||||
|
# debug=True, # If debug = True, the attributions will be saved in the figures folder
|
||||||
|
# # cfg='/home/nielseni6/PythonScripts/yolov10/ultralytics/cfg/models/v8/yolov8-seg.yaml',
|
||||||
|
# # overrides=dict(task="segment"),
|
||||||
|
# )
|
||||||
|
|
||||||
|
# Save the trained model
|
||||||
|
model.save('yolov10_coco_trained.pt')
|
||||||
|
|
||||||
|
# Evaluate the model on the validation set
|
||||||
|
results = model.val(data='coco.yaml')
|
||||||
|
|
||||||
|
# Print the evaluation results
|
||||||
|
print(results)
|
63
run_pgt_train.py
Normal file
63
run_pgt_train.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
from ultralytics import YOLOv10, YOLO, YOLOv10PGT
|
||||||
|
# from ultralytics.engine.pgt_trainer import PGTTrainer
|
||||||
|
import os
|
||||||
|
from ultralytics.models.yolo.segment import PGTSegmentationTrainer
|
||||||
|
import argparse
|
||||||
|
from datetime import datetime
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# nohup python run_pgt_train.py --device 7 > ./output_logs/gpu7_yolov10_pgt_train.log 2>&1 &
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
model = YOLOv10PGT('yolov10n.pt')
|
||||||
|
|
||||||
|
if args.pgt_coeff is None:
|
||||||
|
model.train(data=args.data_yaml, epochs=args.epochs, batch=args.batch_size)
|
||||||
|
else:
|
||||||
|
model.train(
|
||||||
|
data=args.data_yaml,
|
||||||
|
epochs=args.epochs,
|
||||||
|
batch=args.batch_size,
|
||||||
|
# amp=False,
|
||||||
|
pgt_coeff=args.pgt_coeff,
|
||||||
|
# cfg='pgt_train.yaml', # Load and train model with the config file
|
||||||
|
)
|
||||||
|
# If you want to finetune the model with pretrained weights, you could load the
|
||||||
|
# pretrained weights like below
|
||||||
|
# model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
|
||||||
|
# or
|
||||||
|
# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
|
||||||
|
# model = YOLOv10('yolov10n.pt', task='segment')
|
||||||
|
|
||||||
|
# Create a directory to save model weights if it doesn't exist
|
||||||
|
model_weights_dir = 'model_weights'
|
||||||
|
if not os.path.exists(model_weights_dir):
|
||||||
|
os.makedirs(model_weights_dir)
|
||||||
|
|
||||||
|
# Save the trained model with a unique name based on the current date and time
|
||||||
|
current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||||
|
data_yaml_base = os.path.splitext(os.path.basename(args.data_yaml))[0]
|
||||||
|
model_save_path = os.path.join(model_weights_dir, f'yolov10_{data_yaml_base}_trained_{current_time}.pt')
|
||||||
|
model.save(model_save_path)
|
||||||
|
# torch.save(trainer.model.state_dict(), model_save_path)
|
||||||
|
|
||||||
|
# Evaluate the model on the validation set
|
||||||
|
results = model.val(data=args.data_yaml)
|
||||||
|
|
||||||
|
# Print the evaluation results
|
||||||
|
print(results)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description='Train YOLOv10 model with PGT segmentation.')
|
||||||
|
parser.add_argument('--device', type=str, default='0', help='CUDA device number')
|
||||||
|
parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
|
||||||
|
parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for training')
|
||||||
|
parser.add_argument('--data_yaml', type=str, default='coco.yaml', help='Path to the data YAML file')
|
||||||
|
parser.add_argument('--pgt_coeff', type=float, default=None, help='Coefficient for PGT')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Set CUDA device (only needed for multi-gpu machines)
|
||||||
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
|
||||||
|
main(args)
|
||||||
|
|
83
run_train.py
Normal file
83
run_train.py
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
from ultralytics import YOLOv10, YOLO
|
||||||
|
# from ultralytics.engine.pgt_trainer import PGTTrainer
|
||||||
|
# from ultralytics import BaseTrainer
|
||||||
|
# from ultralytics.engine.trainer import BaseTrainer
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Set CUDA device (only needed for multi-gpu machines)
|
||||||
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
|
||||||
|
|
||||||
|
model = YOLOv10()
|
||||||
|
# model = YOLO()
|
||||||
|
# If you want to finetune the model with pretrained weights, you could load the
|
||||||
|
# pretrained weights like below
|
||||||
|
# model = YOLOv10.from_pretrained('jameslahm/yolov10{n/s/m/b/l/x}')
|
||||||
|
# or
|
||||||
|
# wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10{n/s/m/b/l/x}.pt
|
||||||
|
# model = YOLOv10('yolov10m.pt')
|
||||||
|
|
||||||
|
model.train(data='coco.yaml',
|
||||||
|
# Add return_images as input parameter
|
||||||
|
epochs=500, batch=16, imgsz=640,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save the trained model
|
||||||
|
model.save('yolov10_coco_trained.pt')
|
||||||
|
|
||||||
|
# Evaluate the model on the validation set
|
||||||
|
results = model.val(data='coco.yaml')
|
||||||
|
|
||||||
|
# Print the evaluation results
|
||||||
|
print(results)
|
||||||
|
|
||||||
|
# import torch
|
||||||
|
# from torch.utils.data import DataLoader
|
||||||
|
# from torchvision import datasets, transforms
|
||||||
|
|
||||||
|
# # Define the transformation for the dataset
|
||||||
|
# transform = transforms.Compose([
|
||||||
|
# transforms.Resize((640, 640)),
|
||||||
|
# transforms.ToTensor()
|
||||||
|
# ])
|
||||||
|
|
||||||
|
# # Load the COCO dataset
|
||||||
|
# train_dataset = datasets.CocoDetection(root='data/nielseni6/coco/train2017', annFile='/data/nielseni6/coco/annotations/instances_train2017.json', transform=transform)
|
||||||
|
# val_dataset = datasets.CocoDetection(root='data/nielseni6/coco/val2017', annFile='/data/nielseni6/coco/annotations/instances_val2017.json', transform=transform)
|
||||||
|
|
||||||
|
# # Create data loaders
|
||||||
|
# train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4)
|
||||||
|
# val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=4)
|
||||||
|
|
||||||
|
# model = YOLOv10()
|
||||||
|
|
||||||
|
# # Define the optimizer
|
||||||
|
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
||||||
|
|
||||||
|
# # Training loop
|
||||||
|
# for epoch in range(500):
|
||||||
|
# model.train()
|
||||||
|
# for images, targets in train_loader:
|
||||||
|
# images = images.to('cuda')
|
||||||
|
# targets = [{k: v.to('cuda') for k, v in t.items()} for t in targets]
|
||||||
|
# loss = model(images, targets)
|
||||||
|
# loss.backward()
|
||||||
|
# optimizer.step()
|
||||||
|
# optimizer.zero_grad()
|
||||||
|
|
||||||
|
# # Validation loop
|
||||||
|
# model.eval()
|
||||||
|
# with torch.no_grad():
|
||||||
|
# for images, targets in val_loader:
|
||||||
|
# images = images.to('cuda')
|
||||||
|
# targets = [{k: v.to('cuda') for k, v in t.items()} for t in targets]
|
||||||
|
# results = model(images, targets)
|
||||||
|
|
||||||
|
# # Save the trained model
|
||||||
|
# model.save('yolov10_coco_trained.pt')
|
||||||
|
|
||||||
|
# # Evaluate the model on the validation set
|
||||||
|
# results = model.val(data='coco.yaml')
|
||||||
|
|
||||||
|
# # Print the evaluation results
|
||||||
|
# print(results)
|
32
run_val.py
Normal file
32
run_val.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from ultralytics import YOLOv10, YOLO, YOLOv10PGT
|
||||||
|
# from ultralytics.engine.pgt_trainer import PGTTrainer
|
||||||
|
import os
|
||||||
|
from ultralytics.models.yolo.segment import PGTSegmentationTrainer
|
||||||
|
import argparse
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# nohup python run_val.py --device 1 > ./output_logs/gpu1_yolov10_val.log 2>&1 &
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
|
||||||
|
model = YOLOv10PGT(args.model_path)
|
||||||
|
|
||||||
|
# Evaluate the model on the validation set
|
||||||
|
results = model.val(data=args.data_yaml)
|
||||||
|
|
||||||
|
# Print the evaluation results
|
||||||
|
print(results)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description='Train YOLOv10 model with PGT segmentation.')
|
||||||
|
parser.add_argument('--device', type=str, default='1', help='CUDA device number')
|
||||||
|
parser.add_argument('--batch_size', type=int, default=64, help='Batch size for training')
|
||||||
|
parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for training')
|
||||||
|
parser.add_argument('--data_yaml', type=str, default='coco.yaml', help='Path to the data YAML file')
|
||||||
|
parser.add_argument('--model_path', type=str, default='yolov10n.pt', help='Path to the model file')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Set CUDA device (only needed for multi-gpu machines)
|
||||||
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
|
||||||
|
main(args)
|
@ -3,7 +3,7 @@
|
|||||||
__version__ = "8.1.34"
|
__version__ = "8.1.34"
|
||||||
|
|
||||||
from ultralytics.data.explorer.explorer import Explorer
|
from ultralytics.data.explorer.explorer import Explorer
|
||||||
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld, YOLOv10
|
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld, YOLOv10, YOLOv10PGT
|
||||||
from ultralytics.models.fastsam import FastSAM
|
from ultralytics.models.fastsam import FastSAM
|
||||||
from ultralytics.models.nas import NAS
|
from ultralytics.models.nas import NAS
|
||||||
from ultralytics.utils import ASSETS, SETTINGS as settings
|
from ultralytics.utils import ASSETS, SETTINGS as settings
|
||||||
@ -23,5 +23,6 @@ __all__ = (
|
|||||||
"download",
|
"download",
|
||||||
"settings",
|
"settings",
|
||||||
"Explorer",
|
"Explorer",
|
||||||
"YOLOv10"
|
"YOLOv10",
|
||||||
|
"YOLOv10PGT",
|
||||||
)
|
)
|
||||||
|
@ -41,6 +41,7 @@ overlap_mask: True # (bool) masks should overlap during training (segment train
|
|||||||
mask_ratio: 4 # (int) mask downsample ratio (segment train only)
|
mask_ratio: 4 # (int) mask downsample ratio (segment train only)
|
||||||
# Classification
|
# Classification
|
||||||
dropout: 0.0 # (float) use dropout regularization (classify train only)
|
dropout: 0.0 # (float) use dropout regularization (classify train only)
|
||||||
|
pgt_coeff: 2.0 # (float) PGT loss coefficient
|
||||||
|
|
||||||
# Val/Test settings ----------------------------------------------------------------------------------------------------
|
# Val/Test settings ----------------------------------------------------------------------------------------------------
|
||||||
val: True # (bool) validate/test during training
|
val: True # (bool) validate/test during training
|
||||||
|
128
ultralytics/cfg/pgt_train.yaml
Normal file
128
ultralytics/cfg/pgt_train.yaml
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
# Default training settings and hyperparameters for medium-augmentation COCO training
|
||||||
|
|
||||||
|
task: detect # (str) YOLO task, i.e. detect, segment, classify, pose
|
||||||
|
mode: train # (str) YOLO mode, i.e. train, val, predict, export, track, benchmark
|
||||||
|
|
||||||
|
# Train settings -------------------------------------------------------------------------------------------------------
|
||||||
|
model: # (str, optional) path to model file, i.e. yolov8n.pt, yolov8n.yaml
|
||||||
|
data: # (str, optional) path to data file, i.e. coco128.yaml
|
||||||
|
epochs: 100 # (int) number of epochs to train for
|
||||||
|
time: # (float, optional) number of hours to train for, overrides epochs if supplied
|
||||||
|
patience: 100 # (int) epochs to wait for no observable improvement for early stopping of training
|
||||||
|
batch: 16 # (int) number of images per batch (-1 for AutoBatch)
|
||||||
|
imgsz: 640 # (int | list) input images size as int for train and val modes, or list[w,h] for predict and export modes
|
||||||
|
save: True # (bool) save train checkpoints and predict results
|
||||||
|
save_period: -1 # (int) Save checkpoint every x epochs (disabled if < 1)
|
||||||
|
val_period: 1 # (int) Validation every x epochs
|
||||||
|
cache: False # (bool) True/ram, disk or False. Use cache for data loading
|
||||||
|
device: # (int | str | list, optional) device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu
|
||||||
|
workers: 8 # (int) number of worker threads for data loading (per RANK if DDP)
|
||||||
|
project: # (str, optional) project name
|
||||||
|
name: # (str, optional) experiment name, results saved to 'project/name' directory
|
||||||
|
exist_ok: False # (bool) whether to overwrite existing experiment
|
||||||
|
pretrained: True # (bool | str) whether to use a pretrained model (bool) or a model to load weights from (str)
|
||||||
|
optimizer: auto # (str) optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto]
|
||||||
|
verbose: True # (bool) whether to print verbose output
|
||||||
|
seed: 0 # (int) random seed for reproducibility
|
||||||
|
deterministic: True # (bool) whether to enable deterministic mode
|
||||||
|
single_cls: False # (bool) train multi-class data as single-class
|
||||||
|
rect: False # (bool) rectangular training if mode='train' or rectangular validation if mode='val'
|
||||||
|
cos_lr: False # (bool) use cosine learning rate scheduler
|
||||||
|
close_mosaic: 10 # (int) disable mosaic augmentation for final epochs (0 to disable)
|
||||||
|
resume: False # (bool) resume training from last checkpoint
|
||||||
|
amp: True # (bool) Automatic Mixed Precision (AMP) training, choices=[True, False], True runs AMP check
|
||||||
|
fraction: 1.0 # (float) dataset fraction to train on (default is 1.0, all images in train set)
|
||||||
|
profile: False # (bool) profile ONNX and TensorRT speeds during training for loggers
|
||||||
|
freeze: None # (int | list, optional) freeze first n layers, or freeze list of layer indices during training
|
||||||
|
multi_scale: False # (bool) Whether to use multiscale during training
|
||||||
|
# Segmentation
|
||||||
|
overlap_mask: True # (bool) masks should overlap during training (segment train only)
|
||||||
|
mask_ratio: 4 # (int) mask downsample ratio (segment train only)
|
||||||
|
# Classification
|
||||||
|
dropout: 0.0 # (float) use dropout regularization (classify train only)
|
||||||
|
pgt_coeff: 1.0 # (float) PGT loss coefficient
|
||||||
|
|
||||||
|
# Val/Test settings ----------------------------------------------------------------------------------------------------
|
||||||
|
val: True # (bool) validate/test during training
|
||||||
|
split: val # (str) dataset split to use for validation, i.e. 'val', 'test' or 'train'
|
||||||
|
save_json: False # (bool) save results to JSON file
|
||||||
|
save_hybrid: False # (bool) save hybrid version of labels (labels + additional predictions)
|
||||||
|
conf: # (float, optional) object confidence threshold for detection (default 0.25 predict, 0.001 val)
|
||||||
|
iou: 0.7 # (float) intersection over union (IoU) threshold for NMS
|
||||||
|
max_det: 300 # (int) maximum number of detections per image
|
||||||
|
half: False # (bool) use half precision (FP16)
|
||||||
|
dnn: False # (bool) use OpenCV DNN for ONNX inference
|
||||||
|
plots: True # (bool) save plots and images during train/val
|
||||||
|
|
||||||
|
# Predict settings -----------------------------------------------------------------------------------------------------
|
||||||
|
source: # (str, optional) source directory for images or videos
|
||||||
|
vid_stride: 1 # (int) video frame-rate stride
|
||||||
|
stream_buffer: False # (bool) buffer all streaming frames (True) or return the most recent frame (False)
|
||||||
|
visualize: False # (bool) visualize model features
|
||||||
|
augment: False # (bool) apply image augmentation to prediction sources
|
||||||
|
agnostic_nms: False # (bool) class-agnostic NMS
|
||||||
|
classes: # (int | list[int], optional) filter results by class, i.e. classes=0, or classes=[0,2,3]
|
||||||
|
retina_masks: False # (bool) use high-resolution segmentation masks
|
||||||
|
embed: # (list[int], optional) return feature vectors/embeddings from given layers
|
||||||
|
|
||||||
|
# Visualize settings ---------------------------------------------------------------------------------------------------
|
||||||
|
show: False # (bool) show predicted images and videos if environment allows
|
||||||
|
save_frames: False # (bool) save predicted individual video frames
|
||||||
|
save_txt: False # (bool) save results as .txt file
|
||||||
|
save_conf: False # (bool) save results with confidence scores
|
||||||
|
save_crop: False # (bool) save cropped images with results
|
||||||
|
show_labels: True # (bool) show prediction labels, i.e. 'person'
|
||||||
|
show_conf: True # (bool) show prediction confidence, i.e. '0.99'
|
||||||
|
show_boxes: True # (bool) show prediction boxes
|
||||||
|
line_width: # (int, optional) line width of the bounding boxes. Scaled to image size if None.
|
||||||
|
|
||||||
|
# Export settings ------------------------------------------------------------------------------------------------------
|
||||||
|
format: torchscript # (str) format to export to, choices at https://docs.ultralytics.com/modes/export/#export-formats
|
||||||
|
keras: False # (bool) use Keras
|
||||||
|
optimize: False # (bool) TorchScript: optimize for mobile
|
||||||
|
int8: False # (bool) CoreML/TF INT8 quantization
|
||||||
|
dynamic: False # (bool) ONNX/TF/TensorRT: dynamic axes
|
||||||
|
simplify: False # (bool) ONNX: simplify model using `onnxslim`
|
||||||
|
opset: # (int, optional) ONNX: opset version
|
||||||
|
workspace: 4 # (int) TensorRT: workspace size (GB)
|
||||||
|
nms: False # (bool) CoreML: add NMS
|
||||||
|
|
||||||
|
# Hyperparameters ------------------------------------------------------------------------------------------------------
|
||||||
|
lr0: 0.01 # (float) initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
|
||||||
|
lrf: 0.01 # (float) final learning rate (lr0 * lrf)
|
||||||
|
momentum: 0.937 # (float) SGD momentum/Adam beta1
|
||||||
|
weight_decay: 0.0005 # (float) optimizer weight decay 5e-4
|
||||||
|
warmup_epochs: 3.0 # (float) warmup epochs (fractions ok)
|
||||||
|
warmup_momentum: 0.8 # (float) warmup initial momentum
|
||||||
|
warmup_bias_lr: 0.1 # (float) warmup initial bias lr
|
||||||
|
box: 7.5 # (float) box loss gain
|
||||||
|
cls: 0.5 # (float) cls loss gain (scale with pixels)
|
||||||
|
dfl: 1.5 # (float) dfl loss gain
|
||||||
|
pose: 12.0 # (float) pose loss gain
|
||||||
|
kobj: 1.0 # (float) keypoint obj loss gain
|
||||||
|
label_smoothing: 0.0 # (float) label smoothing (fraction)
|
||||||
|
nbs: 64 # (int) nominal batch size
|
||||||
|
hsv_h: 0.015 # (float) image HSV-Hue augmentation (fraction)
|
||||||
|
hsv_s: 0.7 # (float) image HSV-Saturation augmentation (fraction)
|
||||||
|
hsv_v: 0.4 # (float) image HSV-Value augmentation (fraction)
|
||||||
|
degrees: 0.0 # (float) image rotation (+/- deg)
|
||||||
|
translate: 0.1 # (float) image translation (+/- fraction)
|
||||||
|
scale: 0.5 # (float) image scale (+/- gain)
|
||||||
|
shear: 0.0 # (float) image shear (+/- deg)
|
||||||
|
perspective: 0.0 # (float) image perspective (+/- fraction), range 0-0.001
|
||||||
|
flipud: 0.0 # (float) image flip up-down (probability)
|
||||||
|
fliplr: 0.5 # (float) image flip left-right (probability)
|
||||||
|
bgr: 0.0 # (float) image channel BGR (probability)
|
||||||
|
mosaic: 1.0 # (float) image mosaic (probability)
|
||||||
|
mixup: 0.0 # (float) image mixup (probability)
|
||||||
|
copy_paste: 0.0 # (float) segment copy-paste (probability)
|
||||||
|
auto_augment: randaugment # (str) auto augmentation policy for classification (randaugment, autoaugment, augmix)
|
||||||
|
erasing: 0.4 # (float) probability of random erasing during classification training (0-1)
|
||||||
|
crop_fraction: 1.0 # (float) image crop fraction for classification evaluation/inference (0-1)
|
||||||
|
|
||||||
|
# Custom config.yaml ---------------------------------------------------------------------------------------------------
|
||||||
|
cfg: # (str, optional) for overriding defaults.yaml
|
||||||
|
|
||||||
|
# Tracker settings ------------------------------------------------------------------------------------------------------
|
||||||
|
tracker: botsort.yaml # (str) tracker type, choices=[botsort.yaml, bytetrack.yaml]
|
@ -387,6 +387,7 @@ class Model(nn.Module):
|
|||||||
source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
|
source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
predictor=None,
|
predictor=None,
|
||||||
|
return_images: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> list:
|
) -> list:
|
||||||
"""
|
"""
|
||||||
@ -438,7 +439,7 @@ class Model(nn.Module):
|
|||||||
self.predictor.save_dir = get_save_dir(self.predictor.args)
|
self.predictor.save_dir = get_save_dir(self.predictor.args)
|
||||||
if prompts and hasattr(self.predictor, "set_prompts"): # for SAM-type models
|
if prompts and hasattr(self.predictor, "set_prompts"): # for SAM-type models
|
||||||
self.predictor.set_prompts(prompts)
|
self.predictor.set_prompts(prompts)
|
||||||
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
|
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream, return_images=return_images)
|
||||||
|
|
||||||
def track(
|
def track(
|
||||||
self,
|
self,
|
||||||
@ -592,6 +593,7 @@ class Model(nn.Module):
|
|||||||
def train(
|
def train(
|
||||||
self,
|
self,
|
||||||
trainer=None,
|
trainer=None,
|
||||||
|
debug=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -654,6 +656,87 @@ class Model(nn.Module):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
self.trainer.hub_session = self.session # attach optional HUB session
|
self.trainer.hub_session = self.session # attach optional HUB session
|
||||||
|
if debug:
|
||||||
|
self.trainer.train(debug=debug)
|
||||||
|
else:
|
||||||
|
self.trainer.train()
|
||||||
|
# Update model and cfg after training
|
||||||
|
if RANK in (-1, 0):
|
||||||
|
ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
|
||||||
|
self.model, _ = attempt_load_one_weight(ckpt)
|
||||||
|
self.overrides = self.model.args
|
||||||
|
self.metrics = getattr(self.trainer.validator, "metrics", None) # TODO: no metrics returned by DDP
|
||||||
|
return self.metrics
|
||||||
|
|
||||||
|
def train_pgt( # Currently unused, but should be considered if changes to the train function are made
|
||||||
|
self,
|
||||||
|
trainer=None,
|
||||||
|
debug=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Trains the model using the specified dataset and training configuration.
|
||||||
|
|
||||||
|
This method facilitates model training with a range of customizable settings and configurations. It supports
|
||||||
|
training with a custom trainer or the default training approach defined in the method. The method handles
|
||||||
|
different scenarios, such as resuming training from a checkpoint, integrating with Ultralytics HUB, and
|
||||||
|
updating model and configuration after training.
|
||||||
|
|
||||||
|
When using Ultralytics HUB, if the session already has a loaded model, the method prioritizes HUB training
|
||||||
|
arguments and issues a warning if local arguments are provided. It checks for pip updates and combines default
|
||||||
|
configurations, method-specific defaults, and user-provided arguments to configure the training process. After
|
||||||
|
training, it updates the model and its configurations, and optionally attaches metrics.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trainer (BaseTrainer, optional): An instance of a custom trainer class for training the model. If None, the
|
||||||
|
method uses a default trainer. Defaults to None.
|
||||||
|
**kwargs (any): Arbitrary keyword arguments representing the training configuration. These arguments are
|
||||||
|
used to customize various aspects of the training process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(dict | None): Training metrics if available and training is successful; otherwise, None.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If the model is not a PyTorch model.
|
||||||
|
PermissionError: If there is a permission issue with the HUB session.
|
||||||
|
ModuleNotFoundError: If the HUB SDK is not installed.
|
||||||
|
"""
|
||||||
|
self._check_is_pytorch_model()
|
||||||
|
if hasattr(self.session, "model") and self.session.model.id: # Ultralytics HUB session with loaded model
|
||||||
|
if any(kwargs):
|
||||||
|
LOGGER.warning("WARNING ⚠️ using HUB training arguments, ignoring local training arguments.")
|
||||||
|
kwargs = self.session.train_args # overwrite kwargs
|
||||||
|
|
||||||
|
checks.check_pip_update_available()
|
||||||
|
|
||||||
|
overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides
|
||||||
|
custom = {"data": DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task]} # method defaults
|
||||||
|
args = {**overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right
|
||||||
|
if args.get("resume"):
|
||||||
|
args["resume"] = self.ckpt_path
|
||||||
|
|
||||||
|
self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)
|
||||||
|
if not args.get("resume"): # manually set model only if not resuming
|
||||||
|
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
|
||||||
|
self.model = self.trainer.model
|
||||||
|
|
||||||
|
if SETTINGS["hub"] is True and not self.session:
|
||||||
|
# Create a model in HUB
|
||||||
|
try:
|
||||||
|
self.session = self._get_hub_session(self.model_name)
|
||||||
|
if self.session:
|
||||||
|
self.session.create_model(args)
|
||||||
|
# Check model was created
|
||||||
|
if not getattr(self.session.model, "id", None):
|
||||||
|
self.session = None
|
||||||
|
except (PermissionError, ModuleNotFoundError):
|
||||||
|
# Ignore PermissionError and ModuleNotFoundError which indicates hub-sdk not installed
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.trainer.hub_session = self.session # attach optional HUB session
|
||||||
|
if debug:
|
||||||
|
self.trainer.train(debug=debug)
|
||||||
|
else:
|
||||||
self.trainer.train()
|
self.trainer.train()
|
||||||
# Update model and cfg after training
|
# Update model and cfg after training
|
||||||
if RANK in (-1, 0):
|
if RANK in (-1, 0):
|
||||||
|
850
ultralytics/engine/pgt_trainer.py
Normal file
850
ultralytics/engine/pgt_trainer.py
Normal file
@ -0,0 +1,850 @@
|
|||||||
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
"""
|
||||||
|
Train a model on a dataset.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
$ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
import warnings
|
||||||
|
from copy import deepcopy
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import distributed as dist
|
||||||
|
from torch import nn, optim
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import torchvision.transforms as T
|
||||||
|
|
||||||
|
from ultralytics.cfg import get_cfg, get_save_dir
|
||||||
|
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
||||||
|
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
|
||||||
|
from ultralytics.utils import (
|
||||||
|
DEFAULT_CFG,
|
||||||
|
LOGGER,
|
||||||
|
RANK,
|
||||||
|
TQDM,
|
||||||
|
__version__,
|
||||||
|
callbacks,
|
||||||
|
clean_url,
|
||||||
|
colorstr,
|
||||||
|
emojis,
|
||||||
|
yaml_save,
|
||||||
|
)
|
||||||
|
from ultralytics.utils.autobatch import check_train_batch_size
|
||||||
|
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
|
||||||
|
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
|
||||||
|
from ultralytics.utils.files import get_latest_run
|
||||||
|
from ultralytics.utils.torch_utils import (
|
||||||
|
EarlyStopping,
|
||||||
|
ModelEMA,
|
||||||
|
de_parallel,
|
||||||
|
init_seeds,
|
||||||
|
one_cycle,
|
||||||
|
select_device,
|
||||||
|
strip_optimizer,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ultralytics.utils.loss import v8DetectionLoss
|
||||||
|
from ultralytics.utils.plaus_functs import get_dist_reg, plaus_loss_fn
|
||||||
|
import matplotlib.path as matplotlib_path
|
||||||
|
class PGTTrainer:
|
||||||
|
"""
|
||||||
|
BaseTrainer.
|
||||||
|
|
||||||
|
A base class for creating trainers.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
args (SimpleNamespace): Configuration for the trainer.
|
||||||
|
validator (BaseValidator): Validator instance.
|
||||||
|
model (nn.Module): Model instance.
|
||||||
|
callbacks (defaultdict): Dictionary of callbacks.
|
||||||
|
save_dir (Path): Directory to save results.
|
||||||
|
wdir (Path): Directory to save weights.
|
||||||
|
last (Path): Path to the last checkpoint.
|
||||||
|
best (Path): Path to the best checkpoint.
|
||||||
|
save_period (int): Save checkpoint every x epochs (disabled if < 1).
|
||||||
|
batch_size (int): Batch size for training.
|
||||||
|
epochs (int): Number of epochs to train for.
|
||||||
|
start_epoch (int): Starting epoch for training.
|
||||||
|
device (torch.device): Device to use for training.
|
||||||
|
amp (bool): Flag to enable AMP (Automatic Mixed Precision).
|
||||||
|
scaler (amp.GradScaler): Gradient scaler for AMP.
|
||||||
|
data (str): Path to data.
|
||||||
|
trainset (torch.utils.data.Dataset): Training dataset.
|
||||||
|
testset (torch.utils.data.Dataset): Testing dataset.
|
||||||
|
ema (nn.Module): EMA (Exponential Moving Average) of the model.
|
||||||
|
resume (bool): Resume training from a checkpoint.
|
||||||
|
lf (nn.Module): Loss function.
|
||||||
|
scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
|
||||||
|
best_fitness (float): The best fitness value achieved.
|
||||||
|
fitness (float): Current fitness value.
|
||||||
|
loss (float): Current loss value.
|
||||||
|
tloss (float): Total loss value.
|
||||||
|
loss_names (list): List of loss names.
|
||||||
|
csv (Path): Path to results CSV file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initializes the trainer from a configuration and optional overrides.

    Resolves the run arguments with `get_cfg`, applies resume logic, selects the device, calls `init_seeds`,
    prepares the save/weights directories, checks the dataset and installs the callbacks.

    Args:
        cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
        overrides (dict, optional): Configuration overrides. Defaults to None.
        _callbacks (optional): Callbacks to use instead of `callbacks.get_default_callbacks()`. Defaults to None.

    Raises:
        RuntimeError: If checking the dataset raises; the original exception is chained as the cause.
    """
    self.args = get_cfg(cfg, overrides)
    self.check_resume(overrides)
    self.device = select_device(self.args.device, self.args.batch)
    self.validator = None
    self.metrics = None
    self.plots = {}
    init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

    # Dirs
    self.save_dir = get_save_dir(self.args)
    self.args.name = self.save_dir.name  # update name for loggers
    self.wdir = self.save_dir / "weights"  # weights dir
    if RANK in (-1, 0):  # only rank -1 / 0 touches the filesystem
        self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
        self.args.save_dir = str(self.save_dir)
        yaml_save(self.save_dir / "args.yaml", vars(self.args))  # save run args
    self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
    self.save_period = self.args.save_period

    self.batch_size = self.args.batch
    self.epochs = self.args.epochs
    self.start_epoch = 0
    if RANK == -1:
        print_args(vars(self.args))

    # Device
    if self.device.type in ("cpu", "mps"):
        self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading

    # Model and Dataset
    self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolov8n -> yolov8n.pt
    try:
        # NOTE(review): if neither branch below matches, self.data is not assigned here and the
        # later self.get_dataset(self.data) call would fail - confirm this cannot happen for valid tasks.
        if self.args.task == "classify":
            self.data = check_cls_dataset(self.args.data)
        elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
            "detect",
            "segment",
            "pose",
            "obb",
        ):
            self.data = check_det_dataset(self.args.data)
            if "yaml_file" in self.data:
                self.args.data = self.data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
    except Exception as e:
        raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e

    self.trainset, self.testset = self.get_dataset(self.data)
    self.ema = None

    # Optimization utils init
    self.lf = None
    self.scheduler = None

    # Epoch level metrics
    self.best_fitness = None
    self.fitness = None
    self.loss = None
    self.tloss = None
    self.loss_names = ["Loss"]
    self.csv = self.save_dir / "results.csv"
    self.plot_idx = [0, 1, 2]  # global batch indices at which training samples are plotted
    self.num = int(time.time())  # integer timestamp taken at construction time

    # Callbacks
    self.callbacks = _callbacks or callbacks.get_default_callbacks()
    if RANK in (-1, 0):
        callbacks.add_integration_callbacks(self)
|
||||||
|
|
||||||
|
def add_callback(self, event: str, callback):
    """Register `callback` as one more handler for `event`, after those already registered."""
    handlers = self.callbacks[event]
    handlers.append(callback)
|
||||||
|
|
||||||
|
def set_callback(self, event: str, callback):
    """Replace every handler registered for `event` with the single given `callback`."""
    replacement = [callback]
    self.callbacks[event] = replacement
|
||||||
|
|
||||||
|
def run_callbacks(self, event: str):
    """Invoke, in registration order, every handler stored for `event`, passing the trainer itself."""
    handlers = self.callbacks.get(event, [])  # unknown events simply have no handlers
    for handler in handlers:
        handler(self)
|
||||||
|
|
||||||
|
def train(self, debug=False):
    """
    Determine the world size from `self.args.device` and start training.

    A non-empty device string counts its comma-separated entries, a tuple/list counts its elements, and any
    other value (i.e. device='', device=None or a number) defaults to one device when CUDA is available and
    zero otherwise. With more than one device and no `LOCAL_RANK` in the environment, a DDP command is
    generated and run as a subprocess; otherwise training runs in this process.

    Args:
        debug (bool, optional): Forwarded to `_do_train` on the in-process path. Defaults to False.
            NOTE(review): it is not passed explicitly on the DDP subprocess path - confirm whether
            `generate_ddp_command` propagates it.

    Raises:
        subprocess.CalledProcessError: If the DDP subprocess exits with a non-zero status.
    """
    device = self.args.device
    if isinstance(device, str) and len(device):  # i.e. device='0' or device='0,1,2,3'
        world_size = len(device.split(","))
    elif isinstance(device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
        world_size = len(device)
    elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
        world_size = 1  # default to device 0
    else:  # no usable device entry and no CUDA
        world_size = 0

    # Run subprocess if DDP training, else train normally
    if world_size > 1 and "LOCAL_RANK" not in os.environ:
        # Argument checks
        if self.args.rect:
            LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
            self.args.rect = False
        if self.args.batch == -1:
            LOGGER.warning(
                "WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting "
                "default 'batch=16'"
            )
            self.args.batch = 16

        # Command
        cmd, file = generate_ddp_command(world_size, self)
        try:
            LOGGER.info(f'{colorstr("DDP:")} debug command {" ".join(cmd)}')
            subprocess.run(cmd, check=True)
        finally:
            # Errors propagate unchanged; the generated DDP file is always cleaned up.
            ddp_cleanup(self, str(file))

    else:
        self._do_train(world_size, debug=debug)
|
||||||
|
|
||||||
|
def _setup_scheduler(self):
    """Create the per-epoch LR multiplier (`self.lf`) and wrap the optimizer in a LambdaLR scheduler."""
    if self.args.cos_lr:
        schedule = one_cycle(1, self.args.lrf, self.epochs)  # cosine 1->hyp['lrf']
    else:

        def schedule(x):
            """Linear decay from 1.0 at epoch 0 down to `lrf` at `self.epochs`, clamped afterwards."""
            remaining = max(1 - x / self.epochs, 0)
            return remaining * (1.0 - self.args.lrf) + self.args.lrf

    self.lf = schedule
    self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
|
||||||
|
|
||||||
|
def _setup_ddp(self, world_size):
    """Bind this process to CUDA device `RANK` and join the distributed process group."""
    torch.cuda.set_device(RANK)
    self.device = torch.device("cuda", RANK)
    # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
    os.environ["NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout
    backend = "nccl" if dist.is_nccl_available() else "gloo"
    three_hours = timedelta(hours=3)
    dist.init_process_group(backend=backend, timeout=three_hours, rank=RANK, world_size=world_size)
|
||||||
|
|
||||||
|
def _setup_train(self, world_size):
    """
    Builds the model, dataloaders, optimizer and scheduler on the correct rank process.

    Runs the 'on_pretrain_routine_start' callbacks first and 'on_pretrain_routine_end' last. In between it
    moves the model to the device, applies layer freezing, resolves AMP, checks the image size, builds the
    train loader (and, on rank -1/0, the validation loader, validator and EMA), then creates the optimizer,
    scheduler and early stopper and restores state from a checkpoint when resuming.

    Args:
        world_size (int): Number of processes taking part in training; values > 1 enable the DDP paths.
    """

    # Model
    self.run_callbacks("on_pretrain_routine_start")
    ckpt = self.setup_model()
    self.model = self.model.to(self.device)
    self.set_model_attributes()

    # Freeze layers
    # A list is used as given, an int n freezes indices 0..n-1, anything else freezes nothing extra
    freeze_list = (
        self.args.freeze
        if isinstance(self.args.freeze, list)
        else range(self.args.freeze)
        if isinstance(self.args.freeze, int)
        else []
    )
    always_freeze_names = [".dfl"]  # always freeze these layers
    freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
    for k, v in self.model.named_parameters():
        # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
        if any(x in k for x in freeze_layer_names):
            LOGGER.info(f"Freezing layer '{k}'")
            v.requires_grad = False
        elif not v.requires_grad and v.dtype.is_floating_point:  # only floating point Tensor can require gradients
            LOGGER.info(
                f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. "
                "See ultralytics.engine.trainer for customization of frozen layers."
            )
            v.requires_grad = True

    # Check AMP
    self.amp = torch.tensor(self.args.amp).to(self.device)  # True or False
    if self.amp and RANK in (-1, 0):  # Single-GPU and DDP
        callbacks_backup = callbacks.default_callbacks.copy()  # backup callbacks as check_amp() resets them
        self.amp = torch.tensor(check_amp(self.model), device=self.device)
        callbacks.default_callbacks = callbacks_backup  # restore callbacks
    if RANK > -1 and world_size > 1:  # DDP
        dist.broadcast(self.amp, src=0)  # broadcast the tensor from rank 0 to all other ranks (returns None)
    self.amp = bool(self.amp)  # as boolean
    self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
    if world_size > 1:
        self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK])

    # Check imgsz
    gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32)  # grid size (max stride)
    self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
    self.stride = gs  # for multiscale training

    # Batch size
    if self.batch_size == -1 and RANK == -1:  # single-GPU only, estimate best batch size
        self.args.batch = self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)

    # Dataloaders
    batch_size = self.batch_size // max(world_size, 1)  # per-process batch size
    self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
    if RANK in (-1, 0):
        # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
        self.test_loader = self.get_dataloader(
            self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
        )
        self.validator = self.get_validator()
        metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
        self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))  # every metric starts at 0
        self.ema = ModelEMA(self.model)
        if self.args.plots:
            self.plot_training_labels()

    # Optimizer
    self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
    weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
    iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
    self.optimizer = self.build_optimizer(
        model=self.model,
        name=self.args.optimizer,
        lr=self.args.lr0,
        momentum=self.args.momentum,
        decay=weight_decay,
        iterations=iterations,
    )
    # Scheduler
    self._setup_scheduler()
    self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
    self.resume_training(ckpt)
    self.scheduler.last_epoch = self.start_epoch - 1  # do not move
    self.run_callbacks("on_pretrain_routine_end")
|
||||||
|
|
||||||
|
def _do_train(self, world_size=1, debug=False):
    """
    Run the training loop, then evaluate and plot if specified by arguments.

    Sets up DDP (when `world_size > 1`) and the training state, then loops over epochs: per-batch warmup,
    forward/backward under autocast, optimizer steps with gradient accumulation, optional time-based
    stopping, progress logging, and per-epoch validation, metric logging, checkpointing and scheduler
    stepping. After the loop, rank -1/0 runs the final evaluation and plotting.

    Args:
        world_size (int, optional): Number of participating processes. Defaults to 1.
        debug (bool, optional): When True, `plot_grads` is called on every 25th batch. Defaults to False.
    """
    if world_size > 1:
        self._setup_ddp(world_size)
    self._setup_train(world_size)

    nb = len(self.train_loader)  # number of batches
    nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1  # warmup iterations
    last_opt_step = -1
    self.epoch_time = None
    self.epoch_time_start = time.time()
    self.train_time_start = time.time()
    self.run_callbacks("on_train_start")
    LOGGER.info(
        f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
        f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
        f"Logging results to {colorstr('bold', self.save_dir)}\n"
        f'Starting training for ' + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
    )
    if self.args.close_mosaic:
        # Also plot the first three batches after mosaic augmentation is switched off
        base_idx = (self.epochs - self.args.close_mosaic) * nb
        self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
    epoch = self.start_epoch
    while True:
        self.epoch = epoch
        self.run_callbacks("on_train_epoch_start")
        self.model.train()
        if RANK != -1:
            self.train_loader.sampler.set_epoch(epoch)
        pbar = enumerate(self.train_loader)
        # Update dataloader attributes (optional)
        if epoch == (self.epochs - self.args.close_mosaic):
            self._close_dataloader_mosaic()
            self.train_loader.reset()

        if RANK in (-1, 0):
            LOGGER.info(self.progress_string())
            pbar = TQDM(enumerate(self.train_loader), total=nb)
        self.tloss = None
        self.optimizer.zero_grad()
        for i, batch in pbar:
            self.run_callbacks("on_train_batch_start")
            # Warmup
            ni = i + nb * epoch  # global batch index since epoch 0
            if ni <= nw:
                xi = [0, nw]  # x interp
                self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
                for j, x in enumerate(self.optimizer.param_groups):
                    # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x["lr"] = np.interp(
                        ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)]
                    )
                    if "momentum" in x:
                        x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])

            # Forward
            with torch.cuda.amp.autocast(self.amp):
                batch = self.preprocess_batch(batch)
                # Input images are marked as requiring gradients so gradients w.r.t. the images can be taken
                batch['img'] = batch['img'].requires_grad_(True)
                self.loss, self.loss_items = self.model(batch)
                # (self.loss, self.loss_items), images = self.model(batch, return_images=True)

                # smask = get_dist_reg(images, batch['masks'])

                # grad = torch.autograd.grad(self.loss, images, retain_graph=True)[0]
                # grad = torch.abs(grad)

                # self.args.pgt_coeff = 1.1
                # plaus_loss = plaus_loss_fn(grad, smask, self.args.pgt_coeff)
                # self.loss_items = torch.cat((self.loss_items, plaus_loss.unsqueeze(0)))
                # self.loss += plaus_loss

                # NOTE(review): debug_ is reset to False below but is not read again afterwards in this method.
                debug_ = debug
                if debug_ and (i % 25 == 0):
                    debug_ = False

                    # NOTE(review): plot_grads is neither defined nor imported in the visible part of this
                    # file - confirm it is defined elsewhere in the module.
                    plot_grads(batch, self, i)

                if RANK != -1:
                    self.loss *= world_size
                # Running mean of the loss items over the batches seen so far in this epoch
                self.tloss = (
                    (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
                )

            # Backward
            self.scaler.scale(self.loss).backward()

            # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
            if ni - last_opt_step >= self.accumulate:
                self.optimizer_step()
                last_opt_step = ni

                # Timed stopping
                if self.args.time:
                    self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600)
                    if RANK != -1:  # if DDP training
                        broadcast_list = [self.stop if RANK == 0 else None]
                        dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
                        self.stop = broadcast_list[0]
                    if self.stop:  # training time exceeded
                        break

            # Log
            mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G"  # (GB)
            loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1
            losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
            if RANK in (-1, 0):
                pbar.set_description(
                    ("%11s" * 2 + "%11.4g" * (2 + loss_len))
                    % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
                )
                self.run_callbacks("on_batch_end")
                if self.args.plots and ni in self.plot_idx:
                    self.plot_training_samples(batch, ni)
                    plot_grads(batch, self, ni)

            self.run_callbacks("on_train_batch_end")

        self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
        self.run_callbacks("on_train_epoch_end")
        if RANK in (-1, 0):
            final_epoch = epoch + 1 == self.epochs
            self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])

            # Validation
            # Validate every val_period epochs and in the last 10 epochs (when args.val is set), and always on
            # the final epoch or when a stop is pending
            if (self.args.val and (((epoch+1) % self.args.val_period == 0) or (self.epochs - epoch) <= 10)) \
                or final_epoch or self.stopper.possible_stop or self.stop:
                self.metrics, self.fitness = self.validate()
            self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
            self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch
            if self.args.time:
                self.stop |= (time.time() - self.train_time_start) > (self.args.time * 3600)

            # Save model
            if self.args.save or final_epoch:
                self.save_model()
                self.run_callbacks("on_model_save")

        # Scheduler
        t = time.time()
        self.epoch_time = t - self.epoch_time_start
        self.epoch_time_start = t
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
            if self.args.time:
                # Re-estimate the total number of epochs that fit in the time budget and rebuild the scheduler
                mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
                self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
                self._setup_scheduler()
                self.scheduler.last_epoch = self.epoch  # do not move
                self.stop |= epoch >= self.epochs  # stop if exceeded epochs
            self.scheduler.step()
        self.run_callbacks("on_fit_epoch_end")
        torch.cuda.empty_cache()  # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors

        # Early Stopping
        if RANK != -1:  # if DDP training
            broadcast_list = [self.stop if RANK == 0 else None]
            dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
            self.stop = broadcast_list[0]
        if self.stop:
            break  # must break all DDP ranks
        epoch += 1

    if RANK in (-1, 0):
        # Do final val with best.pt
        LOGGER.info(
            f"\n{epoch - self.start_epoch + 1} epochs completed in "
            f"{(time.time() - self.train_time_start) / 3600:.3f} hours."
        )
        self.final_eval()
        if self.args.plots:
            self.plot_metrics()
        self.run_callbacks("on_train_end")
    torch.cuda.empty_cache()
    self.run_callbacks("teardown")
|
||||||
|
|
||||||
|
def save_model(self):
    """
    Write the training checkpoint to disk.

    The checkpoint always goes to `self.last`; it is also written to `self.best` when the current fitness
    equals the best fitness, and to `epoch{N}.pt` every `save_period` epochs (never for epoch 0).
    """
    import pandas as pd  # scope for faster startup

    metrics = dict(self.metrics)
    metrics["fitness"] = self.fitness
    history = pd.read_csv(self.csv).to_dict(orient="list")
    results = {column.strip(): values for column, values in history.items()}

    ckpt = dict(
        epoch=self.epoch,
        best_fitness=self.best_fitness,
        model=deepcopy(de_parallel(self.model)).half(),
        ema=deepcopy(self.ema.ema).half(),
        updates=self.ema.updates,
        optimizer=self.optimizer.state_dict(),
        train_args=vars(self.args),  # save as dict
        train_metrics=metrics,
        train_results=results,
        date=datetime.now().isoformat(),
        version=__version__,
        license="AGPL-3.0 (https://ultralytics.com/license)",
        docs="https://docs.ultralytics.com",
    )

    # Save last and best
    torch.save(ckpt, self.last)
    is_best = self.best_fitness == self.fitness
    if is_best:
        torch.save(ckpt, self.best)
    periodic = (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0)
    if periodic:
        torch.save(ckpt, self.wdir / f"epoch{self.epoch}.pt")
|
||||||
|
|
||||||
|
@staticmethod
def get_dataset(data):
    """
    Pick the train and evaluation entries out of a dataset dict.

    Returns:
        (tuple): `data["train"]` and the "val" entry, falling back to the "test" entry when "val" is missing
            or falsy (the second element is None when neither is present).

    Raises:
        KeyError: If `data` has no "train" key.
    """
    train_split = data["train"]
    eval_split = data.get("val") or data.get("test")
    return train_split, eval_split
|
||||||
|
|
||||||
|
def setup_model(self):
    """
    Turn `self.model` into a model instance unless it already is one.

    Returns:
        The checkpoint returned by `attempt_load_one_weight` when `self.model` ends with ".pt"; None when
        the model was built from a non-".pt" config or was already a `torch.nn.Module`.
    """
    source = self.model
    if isinstance(source, torch.nn.Module):  # if model is loaded beforehand. No setup needed
        return

    weights = None
    ckpt = None
    if str(source).endswith(".pt"):
        weights, ckpt = attempt_load_one_weight(source)
        cfg = ckpt["model"].yaml
    else:
        cfg = source
    self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
    return ckpt
|
||||||
|
|
||||||
|
def optimizer_step(self):
    """Unscale and clip the gradients, step the optimizer through the grad scaler, then update the EMA."""
    scaler, optimizer = self.scaler, self.optimizer
    scaler.unscale_(optimizer)  # unscale gradients
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)  # clip gradients
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
    if self.ema:
        self.ema.update(self.model)
|
||||||
|
|
||||||
|
def preprocess_batch(self, batch):
    """
    Hook for task-specific preprocessing of model inputs and ground truths.

    This base implementation returns `batch` unchanged.
    """
    return batch
|
||||||
|
|
||||||
|
def validate(self):
    """
    Run `self.validator` on this trainer and track the best fitness.

    The "fitness" entry is removed from the validator's metrics; when it is absent, the negated current
    loss is used instead.

    Returns:
        (tuple): The remaining metrics dict and the fitness value.
    """
    metrics = self.validator(self)
    fallback = -self.loss.detach().cpu().numpy()  # use loss as fitness measure if not found
    fitness = metrics.pop("fitness", fallback)
    is_new_best = not self.best_fitness or self.best_fitness < fitness
    if is_new_best:
        self.best_fitness = fitness
    return metrics, fitness
|
||||||
|
|
||||||
|
def get_model(self, cfg=None, weights=None, verbose=True):
    """
    Build the model for this task; this base implementation is a placeholder.

    Args:
        cfg (optional): Model configuration; `setup_model` passes a checkpoint's yaml or the model source.
        weights (optional): Loaded weights, or None.
        verbose (bool, optional): Verbosity flag. Defaults to True.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("This task trainer doesn't support loading cfg files")
|
||||||
|
|
||||||
|
def get_validator(self):
    """
    Return the validator used during training; this base implementation is a placeholder.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("get_validator function not implemented in trainer")
|
||||||
|
|
||||||
|
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
    """
    Return a dataloader for `dataset_path`; this base implementation is a placeholder.

    Args:
        dataset_path: Dataset entry to load (the trainer passes its train or test set).
        batch_size (int, optional): Batch size. Defaults to 16.
        rank (int, optional): Process rank. Defaults to 0.
        mode (str, optional): The trainer passes "train" or "val". Defaults to "train".

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("get_dataloader function not implemented in trainer")
|
||||||
|
|
||||||
|
def build_dataset(self, img_path, mode="train", batch=None):
    """
    Build a dataset; this base implementation is a placeholder.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError("build_dataset function not implemented in trainer")
|
||||||
|
|
||||||
|
def label_loss_items(self, loss_items=None, prefix="train"):
    """
    Label loss values, or list the loss labels when no values are given.

    Returns:
        (dict | list): `{"loss": loss_items}` when `loss_items` is provided, otherwise `["loss"]`.

    Note:
        `prefix` is not used by this base implementation.
    """
    if loss_items is None:
        return ["loss"]
    return {"loss": loss_items}
|
||||||
|
|
||||||
|
def set_model_attributes(self):
    """Copy the dataset's "names" entry onto the model before training."""
    class_names = self.data["names"]
    self.model.names = class_names
|
||||||
|
|
||||||
|
def build_targets(self, preds, targets):
    """Placeholder for building training target tensors; does nothing in this base implementation."""
    return None
|
||||||
|
|
||||||
|
def progress_string(self):
    """Returns the string logged at the start of each epoch; empty in this base implementation."""
    return ""
|
||||||
|
|
||||||
|
# TODO: may need to put these following functions into callback
|
||||||
|
def plot_training_samples(self, batch, ni):
    """Placeholder for plotting a training batch at global batch index `ni`; does nothing here."""
    return None
|
||||||
|
|
||||||
|
def plot_training_labels(self):
    """Placeholder for plotting the training labels; does nothing in this base implementation."""
    return None
|
||||||
|
|
||||||
|
def save_metrics(self, metrics):
    """
    Append one row of metrics for the current epoch to the results CSV.

    A header row ("epoch" followed by the metric names) is written first when the file does not exist yet.
    The epoch column holds `self.epoch + 1`.
    """
    columns = ["epoch", *metrics.keys()]
    values = list(metrics.values())
    width = len(metrics) + 1  # number of cols
    header = ""
    if not self.csv.exists():
        header = ("%23s," * width % tuple(columns)).rstrip(",") + "\n"
    with open(self.csv, "a") as f:
        row = ("%23.5g," * width % tuple([self.epoch + 1] + values)).rstrip(",") + "\n"
        f.write(header + row)
|
||||||
|
|
||||||
|
def plot_metrics(self):
    """Placeholder for plotting metrics after training; does nothing in this base implementation."""
    return None
|
||||||
|
|
||||||
|
def on_plot(self, name, data=None):
    """Record a plot under `Path(name)` together with its data and the current timestamp."""
    key = Path(name)
    record = {"data": data, "timestamp": time.time()}
    self.plots[key] = record
|
||||||
|
|
||||||
|
def final_eval(self):
    """
    Strip the optimizer from the saved checkpoints and validate the best one.

    Each of `self.last` and `self.best` that exists on disk is passed to `strip_optimizer`. The best
    checkpoint is additionally validated, its "fitness" entry dropped, and 'on_fit_epoch_end' callbacks run.
    """
    for ckpt_path in (self.last, self.best):
        if not ckpt_path.exists():
            continue
        strip_optimizer(ckpt_path)  # strip optimizers
        if ckpt_path is not self.best:
            continue
        LOGGER.info(f"\nValidating {ckpt_path}...")
        self.validator.args.plots = self.args.plots
        self.metrics = self.validator(model=ckpt_path)
        self.metrics.pop("fitness", None)
        self.run_callbacks("on_fit_epoch_end")
|
||||||
|
|
||||||
|
def check_resume(self, overrides):
    """
    Check if a resume checkpoint exists and update the arguments accordingly.

    When `self.args.resume` is truthy, the checkpoint is located (the given path if it exists, otherwise the
    latest run), its stored arguments replace `self.args`, and `imgsz`, `batch` and `device` from
    `overrides` are re-applied on top. `self.resume` is set in every case.

    Args:
        overrides (dict | None): Configuration overrides; None is treated as no overrides.

    Raises:
        FileNotFoundError: If anything fails while locating or loading the resume checkpoint.
    """
    resume = self.args.resume
    if resume:
        # The constructor allows overrides=None; without this guard `k in overrides` raises TypeError,
        # which would be misreported below as a missing checkpoint.
        overrides = overrides or {}
        try:
            exists = isinstance(resume, (str, Path)) and Path(resume).exists()
            last = Path(check_file(resume) if exists else get_latest_run())

            # Check that resume data YAML exists, otherwise strip to force re-download of dataset
            ckpt_args = attempt_load_weights(last).args
            if not Path(ckpt_args["data"]).exists():
                ckpt_args["data"] = self.args.data

            resume = True
            self.args = get_cfg(ckpt_args)
            self.args.model = self.args.resume = str(last)  # reinstate model
            for k in "imgsz", "batch", "device":  # allow arg updates to reduce memory or update device on resume
                if k in overrides:
                    setattr(self.args, k, overrides[k])

        except Exception as e:
            raise FileNotFoundError(
                "Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
                "i.e. 'yolo train resume model=path/to/last.pt'"
            ) from e
    self.resume = resume
|
||||||
|
|
||||||
|
def resume_training(self, ckpt):
    """
    Restore optimizer, EMA, best fitness and start epoch from a checkpoint.

    Does nothing when `ckpt` is None or resuming is disabled. When the configured number of epochs is
    smaller than the resume epoch, the checkpoint's epoch count is added to `self.epochs` (fine-tuning).
    Mosaic augmentation is closed right away if the start epoch lies past the close-mosaic point.
    """
    if ckpt is None or not self.resume:
        return

    start_epoch = ckpt["epoch"] + 1
    best_fitness = 0.0
    optimizer_state = ckpt["optimizer"]
    if optimizer_state is not None:
        self.optimizer.load_state_dict(optimizer_state)  # optimizer
        best_fitness = ckpt["best_fitness"]
    if self.ema and ckpt.get("ema"):
        ema_weights = ckpt["ema"].float().state_dict()
        self.ema.ema.load_state_dict(ema_weights)  # EMA
        self.ema.updates = ckpt["updates"]
    assert start_epoch > 0, (
        f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
        f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
    )
    LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs")
    if self.epochs < start_epoch:
        LOGGER.info(
            f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
        )
        self.epochs += ckpt["epoch"]  # finetune additional epochs
    self.best_fitness = best_fitness
    self.start_epoch = start_epoch
    mosaic_cutoff = self.epochs - self.args.close_mosaic
    if start_epoch > mosaic_cutoff:
        self._close_dataloader_mosaic()
|
||||||
|
|
||||||
|
def _close_dataloader_mosaic(self):
|
||||||
|
"""Update dataloaders to stop using mosaic augmentation."""
|
||||||
|
if hasattr(self.train_loader.dataset, "mosaic"):
|
||||||
|
self.train_loader.dataset.mosaic = False
|
||||||
|
if hasattr(self.train_loader.dataset, "close_mosaic"):
|
||||||
|
LOGGER.info("Closing dataloader mosaic")
|
||||||
|
self.train_loader.dataset.close_mosaic(hyp=self.args)
|
||||||
|
|
||||||
|
def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
    """
    Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
    weight decay, and number of iterations.

    Parameters are split into three groups: biases (no decay), weights of normalization layers (no decay)
    and all remaining weights (decayed by `decay`).

    Args:
        model (torch.nn.Module): The model for which to build an optimizer.
        name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
            based on the number of iterations. Default: 'auto'.
        lr (float, optional): The learning rate for the optimizer. Default: 0.001.
        momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
        decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
        iterations (float, optional): The number of iterations, which determines the optimizer if
            name is 'auto'. Default: 1e5.

    Returns:
        (torch.optim.Optimizer): The constructed optimizer.

    Raises:
        NotImplementedError: If `name` is not one of the supported optimizer names.
    """

    g = [], [], []  # optimizer parameter groups: g[0] decayed weights, g[1] norm weights, g[2] biases
    bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
    if name == "auto":
        LOGGER.info(
            f"{colorstr('optimizer:')} 'optimizer=auto' found, "
            f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
            f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
        )
        nc = getattr(model, "nc", 10)  # number of classes
        lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
        # Long schedules use SGD with a fixed lr; short ones use AdamW with the class-count-fitted lr.
        name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
        self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for Adam

    for module_name, module in model.named_modules():
        for param_name, param in module.named_parameters(recurse=False):
            fullname = f"{module_name}.{param_name}" if module_name else param_name
            if "bias" in fullname:  # bias (no decay)
                g[2].append(param)
            elif isinstance(module, bn):  # weight (no decay)
                g[1].append(param)
            else:  # weight (with decay)
                g[0].append(param)

    # The optimizer is created on the bias group; the two weight groups are appended afterwards.
    if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"):
        optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
    elif name == "RMSProp":
        optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
    elif name == "SGD":
        optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
    else:
        # Message lists every name accepted above (including Adamax) and keeps the sentences separated.
        raise NotImplementedError(
            f"Optimizer '{name}' not found in list of available optimizers "
            "[Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, SGD, auto]. "
            "To request support for additional optimizers please visit https://github.com/ultralytics/ultralytics."
        )

    optimizer.add_param_group({"params": g[0], "weight_decay": decay})  # add g0 with weight_decay
    optimizer.add_param_group({"params": g[1], "weight_decay": 0.0})  # add g1 (BatchNorm2d weights)
    LOGGER.info(
        f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
        f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
    )
    return optimizer
|
||||||
|
|
||||||
|
|
||||||
|
def plot_grads(batch, obj, i, nsamples=16):
    """
    Save debug figures showing the loss gradient with respect to the input images.

    For up to `nsamples` images of the batch a six-panel figure is written to
    `<obj.save_dir>/attributions/debug_epoch_<epoch>_batch_<i>_image_<ix>.png` containing the image,
    the absolute input gradient, an overlay of both, the box mask, the mask returned by
    `get_dist_reg`, and the image with its ground-truth boxes.

    Args:
        batch (dict): Training batch; the keys read here are 'img', 'masks', 'resized_shape',
            'batch_idx', 'cls' and 'bboxes'. `batch['img']` is marked as requiring grad in place.
        obj: Object providing `model`, `device`, `save_dir` and `epoch` (presumably the trainer — confirm
            against the caller).
        i (int): Batch index, used only in the output file name.
        nsamples (int): Maximum number of images of the batch to plot.
    """
    images = batch['img'].requires_grad_(True)
    mask = torch.zeros_like(images, dtype=torch.float32)
    smask = get_dist_reg(images, batch['masks'])
    loss, _ = obj.model(batch)
    grad = torch.autograd.grad(loss, images, retain_graph=True)[0]
    grad = torch.abs(grad)

    batch_size = images.shape[0]
    imgsz = torch.tensor(batch['resized_shape'][0]).to(obj.device)
    targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
    # `obj` is passed in the `self` position of v8DetectionLoss.preprocess.
    targets = v8DetectionLoss.preprocess(obj, targets=targets.to(obj.device), batch_size=batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
    gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
    # Rows whose coordinates sum to zero are treated as padding and are not drawn below.
    mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)

    # Set the pixels covered by each ground-truth box to 1 (all-zero rows select an empty slice).
    for irx, bboxes in enumerate(gt_bboxes):
        for box in bboxes:
            x1, y1, x2, y2 = (int(v) for v in torch.round(box))
            mask[irx, :, y1:y2, x1:x2] = 1.0

    # Convert tensors to channel-last numpy arrays for matplotlib
    images_np = images.detach().cpu().numpy().transpose(0, 2, 3, 1)
    grad_np = grad.detach().cpu().numpy().transpose(0, 2, 3, 1)
    mask_np = mask.detach().cpu().numpy().transpose(0, 2, 3, 1)
    seg_mask_np = smask.detach().cpu().numpy().transpose(0, 2, 3, 1)

    # Normalize grad for visualization; a constant gradient would otherwise divide by zero.
    grad_min = grad_np.min()
    grad_span = grad_np.max() - grad_min
    grad_np = (grad_np - grad_min) / grad_span if grad_span > 0 else np.zeros_like(grad_np)

    # Create the output directory once. Formatting the path goes through str() instead of the
    # private, lazily populated `Path._str` attribute.
    save_dir_attr = f"{obj.save_dir}/attributions"
    os.makedirs(save_dir_attr, exist_ok=True)

    range_val = min(nsamples, images_np.shape[0])
    for ix in range(range_val):
        fig, ax = plt.subplots(1, 6, figsize=(30, 5))
        ax[0].imshow(images_np[ix])
        ax[0].set_title('Image')
        ax[1].imshow(grad_np[ix], cmap='jet')
        ax[1].set_title('Gradient')
        ax[2].imshow(images_np[ix])
        ax[2].imshow(grad_np[ix], cmap='jet', alpha=0.5)
        ax[2].set_title('Overlay')
        ax[3].imshow(mask_np[ix], cmap='gray')
        ax[3].set_title('Mask')
        ax[4].imshow(seg_mask_np[ix], cmap='gray')
        ax[4].set_title('Segmentation Mask')

        # Plot image with bounding boxes
        ax[5].imshow(images_np[ix])
        valid_rows = mask_gt[ix].view(-1).bool()
        for bbox, cls, is_valid in zip(gt_bboxes[ix], gt_labels[ix], valid_rows):
            if not is_valid:
                continue  # padding row, not a real box
            x1, y1, x2, y2 = (int(v) for v in torch.round(bbox))
            rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor=np.random.rand(3,), linewidth=2)
            ax[5].add_patch(rect)
            ax[5].text(x1, y1, f'{int(cls)}', color='white', fontsize=12, bbox=dict(facecolor='black', alpha=0.5))
        ax[5].set_title('Bounding Boxes')

        fig.savefig(f'{save_dir_attr}/debug_epoch_{obj.epoch}_batch_{i}_image_{ix}.png')
        plt.close(fig)
|
347
ultralytics/engine/pgt_validator.py
Normal file
347
ultralytics/engine/pgt_validator.py
Normal file
@ -0,0 +1,347 @@
|
|||||||
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
"""
|
||||||
|
Check a model's accuracy on a test or val split of a dataset.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
$ yolo mode=val model=yolov8n.pt data=coco128.yaml imgsz=640
|
||||||
|
|
||||||
|
Usage - formats:
|
||||||
|
$ yolo mode=val model=yolov8n.pt # PyTorch
|
||||||
|
yolov8n.torchscript # TorchScript
|
||||||
|
yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
|
||||||
|
yolov8n_openvino_model # OpenVINO
|
||||||
|
yolov8n.engine # TensorRT
|
||||||
|
yolov8n.mlpackage # CoreML (macOS-only)
|
||||||
|
yolov8n_saved_model # TensorFlow SavedModel
|
||||||
|
yolov8n.pb # TensorFlow GraphDef
|
||||||
|
yolov8n.tflite # TensorFlow Lite
|
||||||
|
yolov8n_edgetpu.tflite # TensorFlow Edge TPU
|
||||||
|
yolov8n_paddle_model # PaddlePaddle
|
||||||
|
yolov8n_ncnn_model # NCNN
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ultralytics.cfg import get_cfg, get_save_dir
|
||||||
|
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
||||||
|
from ultralytics.nn.autobackend import AutoBackend
|
||||||
|
from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
|
||||||
|
from ultralytics.utils.checks import check_imgsz
|
||||||
|
from ultralytics.utils.ops import Profile
|
||||||
|
from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode
|
||||||
|
|
||||||
|
|
||||||
|
class PGTValidator:
    """
    PGTValidator.

    A base class for creating validators. Unlike a plain inference-only validator, `__call__` runs the
    forward pass with the input images marked as requiring gradients (the inference-mode decorator is
    intentionally left disabled on `__call__`).

    Attributes:
        args (SimpleNamespace): Configuration for the validator.
        dataloader (DataLoader): Dataloader to use for validation.
        pbar (tqdm): Progress bar to update during validation.
        model (nn.Module): Model to validate.
        data (dict): Data dictionary.
        device (torch.device): Device to use for validation.
        batch_i (int): Current batch index.
        training (bool): Whether the model is in training mode.
        names (dict): Class names.
        seen: Records the number of images seen so far during validation.
        stats: Placeholder for statistics during validation.
        confusion_matrix: Placeholder for a confusion matrix.
        nc: Number of classes.
        iouv: (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.
        jdict (dict): Dictionary to store JSON validation results.
        speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
            batch processing times in milliseconds.
        save_dir (Path): Directory to save results.
        plots (dict): Dictionary to store plots for visualization.
        callbacks (dict): Dictionary to store various callback functions.
    """

    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
        """
        Initializes a PGTValidator instance.

        Args:
            dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
            save_dir (Path, optional): Directory to save results.
            pbar (tqdm.tqdm): Progress bar for displaying progress.
            args (SimpleNamespace): Configuration for the validator.
            _callbacks (dict): Dictionary to store various callback functions.
        """
        self.args = get_cfg(overrides=args)
        self.dataloader = dataloader
        self.pbar = pbar
        self.stride = None
        self.data = None
        self.device = None
        self.batch_i = None
        self.training = True
        self.names = None
        self.seen = None
        self.stats = None
        self.confusion_matrix = None
        self.nc = None
        self.iouv = None
        self.jdict = None
        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}

        self.save_dir = save_dir or get_save_dir(self.args)
        (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
        if self.args.conf is None:
            self.args.conf = 0.001  # default conf=0.001
        self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)

        self.plots = {}
        self.callbacks = _callbacks or callbacks.get_default_callbacks()

    # @smart_inference_mode()
    def __call__(self, trainer=None, model=None):
        """
        Supports validation of a pre-trained model if passed or a model being trained if trainer is passed (trainer
        gets priority).

        Args:
            trainer: Trainer whose EMA (or raw) model is validated; when given, validation losses are accumulated.
            model: Weights or model used when no trainer is passed; falls back to `self.args.model`.

        Returns:
            (dict): With a trainer, the stats merged with the labelled validation losses, rounded to 5 decimals;
                otherwise the stats as returned by `get_stats()` (and `eval_json()` when JSON saving is enabled).
        """
        self.training = trainer is not None
        augment = self.args.augment and (not self.training)
        if self.training:
            self.device = trainer.device
            self.data = trainer.data
            # self.args.half = self.device.type != "cpu"  # force FP16 val during training
            model = trainer.ema.ema or trainer.model
            model = model.half() if self.args.half else model.float()
            # self.model = model
            self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
            self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
            model.eval()
        else:
            callbacks.add_integration_callbacks(self)
            model = AutoBackend(
                weights=model or self.args.model,
                device=select_device(self.args.device, self.args.batch),
                dnn=self.args.dnn,
                data=self.args.data,
                fp16=self.args.half,
            )
            # self.model = model
            self.device = model.device  # update device
            self.args.half = model.fp16  # update half
            stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
            imgsz = check_imgsz(self.args.imgsz, stride=stride)
            if engine:
                self.args.batch = model.batch_size
            elif not pt and not jit:
                self.args.batch = 1  # export.py models default to batch-size 1
                LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models")

            if str(self.args.data).split(".")[-1] in ("yaml", "yml"):
                self.data = check_det_dataset(self.args.data)
            elif self.args.task == "classify":
                self.data = check_cls_dataset(self.args.data, split=self.args.split)
            else:
                raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))

            if self.device.type in ("cpu", "mps"):
                self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
            if not pt:
                self.args.rect = False
            self.stride = model.stride  # used in get_dataloader() for padding
            self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)

            model.eval()
            model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz))  # warmup

        self.run_callbacks("on_val_start")
        dt = (
            Profile(device=self.device),
            Profile(device=self.device),
            Profile(device=self.device),
            Profile(device=self.device),
        )
        bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
        self.init_metrics(de_parallel(model))
        self.jdict = []  # empty before each val
        for batch_i, batch in enumerate(bar):
            self.run_callbacks("on_val_batch_start")
            self.batch_i = batch_i
            # Preprocess
            with dt[0]:
                batch = self.preprocess(batch)

            # Inference (inputs are marked as requiring grad, so this cannot run under inference mode)
            with dt[1]:
                model.zero_grad()
                preds = model(batch["img"].requires_grad_(True), augment=augment)

            # Loss
            with dt[2]:
                if self.training:
                    self.loss += model.loss(batch, preds)[1]
                # NOTE(review): nesting of this call was reconstructed; it is run for every batch here.
                model.zero_grad()

            # Postprocess
            with dt[3]:
                preds = self.postprocess(preds)

            self.update_metrics(preds, batch)
            if self.args.plots and batch_i < 3:
                self.plot_val_samples(batch, batch_i)
                self.plot_predictions(batch, preds, batch_i)

            self.run_callbacks("on_val_batch_end")
        stats = self.get_stats()
        self.check_stats(stats)
        self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
        self.finalize_metrics()
        if not (self.args.save_json and self.is_coco and len(self.jdict)):
            self.print_results()
        self.run_callbacks("on_val_end")
        if self.training:
            model.float()
            if self.args.save_json and self.jdict:
                stats = self._save_predictions_and_eval(stats)  # update stats
                # NOTE(review): nesting reconstructed — fitness is refreshed only after the JSON evaluation.
                stats['fitness'] = stats['metrics/mAP50-95(B)']
            results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
            return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
        else:
            LOGGER.info(
                "Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image"
                % tuple(self.speed.values())
            )
            if self.args.save_json and self.jdict:
                stats = self._save_predictions_and_eval(stats)  # update stats
            if self.args.plots or self.args.save_json:
                LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
            return stats

    def _save_predictions_and_eval(self, stats):
        """Write `self.jdict` to `predictions.json` in the save directory and return `eval_json(stats)`."""
        with open(str(self.save_dir / "predictions.json"), "w") as f:
            LOGGER.info(f"Saving {f.name}...")
            json.dump(self.jdict, f)  # flatten and save
        return self.eval_json(stats)

    def match_predictions(self, pred_classes, true_classes, iou, use_scipy=False):
        """
        Matches predictions to ground truth objects (pred_classes, true_classes) using IoU.

        Args:
            pred_classes (torch.Tensor): Predicted class indices of shape(N,).
            true_classes (torch.Tensor): Target class indices of shape(M,).
            iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground of truth
            use_scipy (bool): Whether to use scipy for matching (more precise).

        Returns:
            (torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds.
        """
        # Dx10 matrix, where D - detections, 10 - IoU thresholds
        correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
        # LxD matrix where L - labels (rows), D - detections (columns)
        correct_class = true_classes[:, None] == pred_classes
        iou = iou * correct_class  # zero out the wrong classes
        iou = iou.cpu().numpy()
        for i, threshold in enumerate(self.iouv.cpu().tolist()):
            if use_scipy:
                # WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708
                import scipy  # scope import to avoid importing for all commands

                cost_matrix = iou * (iou >= threshold)
                if cost_matrix.any():
                    labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True)
                    valid = cost_matrix[labels_idx, detections_idx] > 0
                    if valid.any():
                        correct[detections_idx[valid], i] = True
            else:
                matches = np.nonzero(iou >= threshold)  # IoU > threshold and classes match
                matches = np.array(matches).T
                if matches.shape[0]:
                    if matches.shape[0] > 1:
                        # Highest IoU first, then keep one match per detection and one per label.
                        matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
                        matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                        # matches = matches[matches[:, 2].argsort()[::-1]]
                        matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
                    correct[matches[:, 1].astype(int), i] = True
        return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)

    def add_callback(self, event: str, callback):
        """Appends the given callback."""
        self.callbacks[event].append(callback)

    def run_callbacks(self, event: str):
        """Runs all callbacks associated with a specified event."""
        for callback in self.callbacks.get(event, []):
            callback(self)

    def get_dataloader(self, dataset_path, batch_size):
        """Get data loader from dataset path and batch size."""
        raise NotImplementedError("get_dataloader function not implemented for this validator")

    def build_dataset(self, img_path):
        """Build dataset."""
        raise NotImplementedError("build_dataset function not implemented in validator")

    def preprocess(self, batch):
        """Preprocesses an input batch; the base implementation returns it unchanged."""
        return batch

    def postprocess(self, preds):
        """Post-processes raw predictions; the base implementation returns them unchanged."""
        return preds

    def init_metrics(self, model):
        """Initialize performance metrics for the YOLO model."""
        pass

    def update_metrics(self, preds, batch):
        """Updates metrics based on predictions and batch."""
        pass

    def finalize_metrics(self, *args, **kwargs):
        """Finalizes and returns all metrics."""
        pass

    def get_stats(self):
        """Returns statistics about the model's performance."""
        return {}

    def check_stats(self, stats):
        """Checks statistics."""
        pass

    def print_results(self):
        """Prints the results of the model's predictions."""
        pass

    def get_desc(self):
        """Get description of the YOLO model."""
        pass

    @property
    def metric_keys(self):
        """Returns the metric keys used in YOLO training/validation."""
        return []

    def on_plot(self, name, data=None):
        """Registers plots (e.g. to be consumed in callbacks)"""
        self.plots[Path(name)] = {"data": data, "timestamp": time.time()}

    # TODO: may need to put these following functions into callback
    def plot_val_samples(self, batch, ni):
        """Plots validation samples during training."""
        pass

    def plot_predictions(self, batch, preds, ni):
        """Plots YOLO model predictions on batch images."""
        pass

    def pred_to_json(self, preds, batch):
        """Convert predictions to JSON format."""
        pass

    def eval_json(self, stats):
        """
        Evaluate the saved JSON predictions and return the (possibly updated) statistics.

        The base implementation performs no evaluation and returns `stats` unchanged, so that
        `stats = self.eval_json(stats)` in `__call__` does not replace the statistics with None.
        """
        return stats
|
@ -206,7 +206,7 @@ class BasePredictor:
|
|||||||
self.vid_writer = {}
|
self.vid_writer = {}
|
||||||
|
|
||||||
@smart_inference_mode()
|
@smart_inference_mode()
|
||||||
def stream_inference(self, source=None, model=None, *args, **kwargs):
|
def stream_inference(self, source=None, model=None, return_images = False, *args, **kwargs):
|
||||||
"""Streams real-time inference on camera feed and saves results to file."""
|
"""Streams real-time inference on camera feed and saves results to file."""
|
||||||
if self.args.verbose:
|
if self.args.verbose:
|
||||||
LOGGER.info("")
|
LOGGER.info("")
|
||||||
@ -243,6 +243,9 @@ class BasePredictor:
|
|||||||
with profilers[0]:
|
with profilers[0]:
|
||||||
im = self.preprocess(im0s)
|
im = self.preprocess(im0s)
|
||||||
|
|
||||||
|
if return_images:
|
||||||
|
im = im.requires_grad_(True)
|
||||||
|
|
||||||
# Inference
|
# Inference
|
||||||
with profilers[1]:
|
with profilers[1]:
|
||||||
preds = self.inference(im, *args, **kwargs)
|
preds = self.inference(im, *args, **kwargs)
|
||||||
@ -272,7 +275,7 @@ class BasePredictor:
|
|||||||
LOGGER.info("\n".join(s))
|
LOGGER.info("\n".join(s))
|
||||||
|
|
||||||
self.run_callbacks("on_predict_batch_end")
|
self.run_callbacks("on_predict_batch_end")
|
||||||
yield from self.results
|
yield from (self.results, im)
|
||||||
|
|
||||||
# Release assets
|
# Release assets
|
||||||
for v in self.vid_writer.values():
|
for v in self.vid_writer.values():
|
||||||
|
@ -3,6 +3,6 @@
|
|||||||
from .rtdetr import RTDETR
|
from .rtdetr import RTDETR
|
||||||
from .sam import SAM
|
from .sam import SAM
|
||||||
from .yolo import YOLO, YOLOWorld
|
from .yolo import YOLO, YOLOWorld
|
||||||
from .yolov10 import YOLOv10
|
from .yolov10 import YOLOv10, YOLOv10PGT
|
||||||
|
|
||||||
__all__ = "YOLO", "RTDETR", "SAM", "YOLOWorld", "YOLOv10" # allow simpler import
|
__all__ = "YOLO", "RTDETR", "SAM", "YOLOWorld", "YOLOv10", "YOLOv10PGT" # allow simpler import
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
from .predict import DetectionPredictor
|
from .predict import DetectionPredictor
|
||||||
|
from .pgt_train import PGTDetectionTrainer
|
||||||
from .train import DetectionTrainer
|
from .train import DetectionTrainer
|
||||||
from .val import DetectionValidator
|
from .val import DetectionValidator
|
||||||
|
from .pgt_val import PGTDetectionValidator
|
||||||
|
|
||||||
__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator"
|
__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator", "PGTDetectionTrainer", "PGTDetectionValidator"
|
||||||
|
144
ultralytics/models/yolo/detect/pgt_train.py
Normal file
144
ultralytics/models/yolo/detect/pgt_train.py
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
from copy import copy
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from ultralytics.data import build_dataloader, build_yolo_dataset
|
||||||
|
from ultralytics.engine.trainer import BaseTrainer
|
||||||
|
from ultralytics.engine.pgt_trainer import PGTTrainer
|
||||||
|
from ultralytics.models import yolo
|
||||||
|
from ultralytics.nn.tasks import DetectionModel
|
||||||
|
from ultralytics.utils import LOGGER, RANK
|
||||||
|
from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
|
||||||
|
from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first
|
||||||
|
|
||||||
|
|
||||||
|
class PGTDetectionTrainer(PGTTrainer):
|
||||||
|
"""
|
||||||
|
A class extending the BaseTrainer class for training based on a detection model.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from ultralytics.models.yolo.detect import DetectionTrainer
|
||||||
|
|
||||||
|
args = dict(model='yolov8n.pt', data='coco8.yaml', epochs=3)
|
||||||
|
trainer = DetectionTrainer(overrides=args)
|
||||||
|
trainer.train()
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def build_dataset(self, img_path, mode="train", batch=None):
|
||||||
|
"""
|
||||||
|
Build YOLO Dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img_path (str): Path to the folder containing images.
|
||||||
|
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
|
||||||
|
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
|
||||||
|
"""
|
||||||
|
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
|
||||||
|
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
|
||||||
|
|
||||||
|
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
|
||||||
|
"""Construct and return dataloader."""
|
||||||
|
assert mode in ["train", "val"]
|
||||||
|
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
||||||
|
dataset = self.build_dataset(dataset_path, mode, batch_size)
|
||||||
|
shuffle = mode == "train"
|
||||||
|
if getattr(dataset, "rect", False) and shuffle:
|
||||||
|
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
|
||||||
|
shuffle = False
|
||||||
|
workers = self.args.workers if mode == "train" else self.args.workers * 2
|
||||||
|
return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader
|
||||||
|
|
||||||
|
def preprocess_batch(self, batch):
|
||||||
|
"""Preprocesses a batch of images by scaling and converting to float."""
|
||||||
|
batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
|
||||||
|
if self.args.multi_scale:
|
||||||
|
imgs = batch["img"]
|
||||||
|
sz = (
|
||||||
|
random.randrange(self.args.imgsz * 0.5, self.args.imgsz * 1.5 + self.stride)
|
||||||
|
// self.stride
|
||||||
|
* self.stride
|
||||||
|
) # size
|
||||||
|
sf = sz / max(imgs.shape[2:]) # scale factor
|
||||||
|
if sf != 1:
|
||||||
|
ns = [
|
||||||
|
math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]
|
||||||
|
] # new shape (stretched to gs-multiple)
|
||||||
|
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
|
||||||
|
batch["img"] = imgs
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def set_model_attributes(self):
|
||||||
|
"""Nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)."""
|
||||||
|
# self.args.box *= 3 / nl # scale to layers
|
||||||
|
# self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
|
||||||
|
# self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
|
||||||
|
self.model.nc = self.data["nc"] # attach number of classes to model
|
||||||
|
self.model.names = self.data["names"] # attach class names to model
|
||||||
|
self.model.args = self.args # attach hyperparameters to model
|
||||||
|
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
|
||||||
|
|
||||||
|
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||||
|
"""Return a YOLO detection model."""
|
||||||
|
model = DetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
|
||||||
|
if weights:
|
||||||
|
model.load(weights)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def get_validator(self):
|
||||||
|
"""Returns a DetectionValidator for YOLO model validation."""
|
||||||
|
self.loss_names = "box_loss", "cls_loss", "dfl_loss"
|
||||||
|
return yolo.detect.DetectionValidator(
|
||||||
|
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
|
||||||
|
)
|
||||||
|
|
||||||
|
def label_loss_items(self, loss_items=None, prefix="train"):
|
||||||
|
"""
|
||||||
|
Returns a loss dict with labelled training loss items tensor.
|
||||||
|
|
||||||
|
Not needed for classification but necessary for segmentation & detection
|
||||||
|
"""
|
||||||
|
keys = [f"{prefix}/{x}" for x in self.loss_names]
|
||||||
|
if loss_items is not None:
|
||||||
|
loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats
|
||||||
|
return dict(zip(keys, loss_items))
|
||||||
|
else:
|
||||||
|
return keys
|
||||||
|
|
||||||
|
def progress_string(self):
|
||||||
|
"""Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size."""
|
||||||
|
return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
|
||||||
|
"Epoch",
|
||||||
|
"GPU_mem",
|
||||||
|
*self.loss_names,
|
||||||
|
"Instances",
|
||||||
|
"Size",
|
||||||
|
)
|
||||||
|
|
||||||
|
def plot_training_samples(self, batch, ni):
    """Plot one training batch with its box annotations to `train_batch{ni}.jpg` in the save dir."""
    output_file = self.save_dir / f"train_batch{ni}.jpg"
    class_ids = batch["cls"].squeeze(-1)
    plot_images(
        images=batch["img"],
        batch_idx=batch["batch_idx"],
        cls=class_ids,
        bboxes=batch["bboxes"],
        paths=batch["im_file"],
        fname=output_file,
        on_plot=self.on_plot,
    )
|
||||||
|
|
||||||
|
def plot_metrics(self):
    """Plot training metrics from the CSV file referenced by `self.csv`."""
    results_csv = self.csv
    plot_results(file=results_csv, on_plot=self.on_plot)  # save results.png
|
||||||
|
|
||||||
|
def plot_training_labels(self):
    """Concatenate every training label's boxes and classes and pass them to `plot_labels`."""
    all_boxes = np.concatenate([entry["bboxes"] for entry in self.train_loader.dataset.labels], 0)
    all_classes = np.concatenate([entry["cls"] for entry in self.train_loader.dataset.labels], 0)
    plot_labels(
        all_boxes,
        all_classes.squeeze(),
        names=self.data["names"],
        save_dir=self.save_dir,
        on_plot=self.on_plot,
    )
|
300
ultralytics/models/yolo/detect/pgt_val.py
Normal file
300
ultralytics/models/yolo/detect/pgt_val.py
Normal file
@ -0,0 +1,300 @@
|
|||||||
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ultralytics.data import build_dataloader, build_yolo_dataset, converter
|
||||||
|
from ultralytics.engine.validator import BaseValidator
|
||||||
|
from ultralytics.engine.pgt_validator import PGTValidator
|
||||||
|
from ultralytics.utils import LOGGER, ops
|
||||||
|
from ultralytics.utils.checks import check_requirements
|
||||||
|
from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
|
||||||
|
from ultralytics.utils.plotting import output_to_target, plot_images
|
||||||
|
|
||||||
|
|
||||||
|
class PGTDetectionValidator(PGTValidator):
|
||||||
|
"""
|
||||||
|
A class extending the BaseValidator class for validation based on a detection model.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
from ultralytics.models.yolo.detect import DetectionValidator
|
||||||
|
|
||||||
|
args = dict(model='yolov8n.pt', data='coco8.yaml')
|
||||||
|
validator = DetectionValidator(args=args)
|
||||||
|
validator()
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
    """Initialise validator state: task name, metrics container, IoU thresholds and label cache."""
    super().__init__(dataloader, save_dir, pbar, args, _callbacks)
    self.args.task = "detect"
    self.nt_per_class = None  # assigned in get_stats()
    self.is_coco = False  # resolved in init_metrics()
    self.class_map = None  # assigned in init_metrics()
    self.lb = []  # for autolabelling
    self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
    iou_thresholds = torch.linspace(0.5, 0.95, 10)  # IoU vector for mAP@0.5:0.95
    self.iouv = iou_thresholds
    self.niou = iou_thresholds.numel()
|
||||||
|
|
||||||
|
def preprocess(self, batch):
    """
    Move a batch to the validator's device and scale the images to [0, 1].

    The image tensor is cast to half or float precision (per `self.args.half`) and divided
    by 255; `batch_idx`, `cls` and `bboxes` are moved to `self.device`. When
    `self.args.save_hybrid` is set, `self.lb` is rebuilt as one tensor per image holding
    `[cls, bbox * (w, h, w, h)]` rows, for autolabelling.

    Args:
        batch (dict): Batch with at least the keys `img`, `batch_idx`, `cls`, `bboxes`.

    Returns:
        (dict): The same dict object, updated in place.
    """
    batch["img"] = batch["img"].to(self.device, non_blocking=True)
    batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
    for k in ["batch_idx", "cls", "bboxes"]:
        batch[k] = batch[k].to(self.device)

    if self.args.save_hybrid:
        height, width = batch["img"].shape[2:]
        nb = len(batch["img"])
        # Scale the boxes by the input image width/height.
        bboxes = batch["bboxes"] * torch.tensor((width, height, width, height), device=self.device)
        per_image = []
        for i in range(nb):
            mask = batch["batch_idx"] == i  # rows belonging to image i (computed once)
            per_image.append(torch.cat([batch["cls"][mask], bboxes[mask]], dim=-1))
        self.lb = per_image  # for autolabelling

    return batch
|
||||||
|
|
||||||
|
def init_metrics(self, model):
    """
    Reset per-run evaluation state and derive dataset-dependent settings.

    Detects a COCO val split from the path stored under `self.data[self.args.split]`, picks
    the class-id map accordingly, copies the class names from `model`, and clears the
    counters and accumulators used by `update_metrics`.
    """
    val_path = self.data.get(self.args.split, "")  # validation path
    coco_suffix = f"{os.sep}val2017.txt"
    self.is_coco = isinstance(val_path, str) and "coco" in val_path and val_path.endswith(coco_suffix)
    if self.is_coco:
        self.class_map = converter.coco80_to_coco91_class()
    else:
        self.class_map = list(range(1000))
    self.args.save_json |= self.is_coco  # run on final val if training COCO

    self.names = model.names
    self.nc = len(model.names)
    self.metrics.names = self.names
    self.metrics.plot = self.args.plots
    self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf)

    self.seen = 0
    self.jdict = []
    self.stats = {"tp": [], "conf": [], "pred_cls": [], "target_cls": []}
|
||||||
|
|
||||||
|
def get_desc(self):
    """Return the header row for the metrics table: a 22-wide class column, then 11-wide columns."""
    columns = ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")
    layout = "%22s" + "%11s" * (len(columns) - 1)
    return layout % columns
|
||||||
|
|
||||||
|
def postprocess(self, preds):
    """Run `ops.non_max_suppression` on raw predictions using the validator's conf/IoU settings."""
    nms_options = dict(
        labels=self.lb,
        multi_label=True,
        agnostic=self.args.single_cls,
        max_det=self.args.max_det,
    )
    return ops.non_max_suppression(preds, self.args.conf, self.args.iou, **nms_options)
|
||||||
|
|
||||||
|
def _prepare_batch(self, si, batch):
|
||||||
|
"""Prepares a batch of images and annotations for validation."""
|
||||||
|
idx = batch["batch_idx"] == si
|
||||||
|
cls = batch["cls"][idx].squeeze(-1)
|
||||||
|
bbox = batch["bboxes"][idx]
|
||||||
|
ori_shape = batch["ori_shape"][si]
|
||||||
|
imgsz = batch["img"].shape[2:]
|
||||||
|
ratio_pad = batch["ratio_pad"][si]
|
||||||
|
if len(cls):
|
||||||
|
bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]] # target boxes
|
||||||
|
ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad) # native-space labels
|
||||||
|
return dict(cls=cls, bbox=bbox, ori_shape=ori_shape, imgsz=imgsz, ratio_pad=ratio_pad)
|
||||||
|
|
||||||
|
def _prepare_pred(self, pred, pbatch):
    """Return a clone of `pred` whose first four columns were rescaled by `ops.scale_boxes`."""
    rescaled = pred.clone()  # leave the caller's tensor untouched
    boxes_view = rescaled[:, :4]
    ops.scale_boxes(pbatch["imgsz"], boxes_view, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])  # native-space pred
    return rescaled
|
||||||
|
|
||||||
|
def update_metrics(self, preds, batch):
    """
    Accumulate per-image statistics for a batch of post-processed predictions.

    For every image this appends `tp`, `conf`, `pred_cls` and `target_cls` entries to
    `self.stats`, optionally feeds the confusion matrix (`self.args.plots`), and optionally
    exports the predictions as JSON (`self.args.save_json`) or txt (`self.args.save_txt`).

    Args:
        preds: Iterable with one prediction tensor per image; columns 4 and 5 are read as
            confidence and class.
        batch (dict): Collated batch; consumed through `_prepare_batch` and `batch["im_file"]`.
    """
    # NOTE(review): nesting below was reconstructed from a copy that lost its indentation - confirm.
    for si, pred in enumerate(preds):
        self.seen += 1
        npr = len(pred)  # number of predictions for this image
        stat = dict(
            conf=torch.zeros(0, device=self.device),
            pred_cls=torch.zeros(0, device=self.device),
            tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
        )
        pbatch = self._prepare_batch(si, batch)
        cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
        nl = len(cls)  # number of ground-truth labels for this image
        stat["target_cls"] = cls
        if npr == 0:
            # No predictions: only record stats when there are labels that were missed.
            if nl:
                for k in self.stats.keys():
                    self.stats[k].append(stat[k])
                if self.args.plots:
                    self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
            continue

        # Predictions
        if self.args.single_cls:
            pred[:, 5] = 0  # collapse every predicted class to 0
        predn = self._prepare_pred(pred, pbatch)
        stat["conf"] = predn[:, 4]
        stat["pred_cls"] = predn[:, 5]

        # Evaluate
        if nl:
            stat["tp"] = self._process_batch(predn, bbox, cls)
            if self.args.plots:
                self.confusion_matrix.process_batch(predn, bbox, cls)
        for k in self.stats.keys():
            self.stats[k].append(stat[k])

        # Save
        if self.args.save_json:
            self.pred_to_json(predn, batch["im_file"][si])
        if self.args.save_txt:
            file = self.save_dir / "labels" / f'{Path(batch["im_file"][si]).stem}.txt'
            self.save_one_txt(predn, self.args.save_conf, pbatch["ori_shape"], file)
|
||||||
|
|
||||||
|
def finalize_metrics(self, *args, **kwargs):
    """Copy the measured speed and the confusion matrix onto the metrics object."""
    metrics = self.metrics
    metrics.speed = self.speed
    metrics.confusion_matrix = self.confusion_matrix
|
||||||
|
|
||||||
|
def get_stats(self):
    """
    Concatenate the accumulated statistics and return the metrics results dict.

    Each list in `self.stats` is concatenated into one numpy array. The metrics object is
    updated only when at least one true positive exists. `self.nt_per_class` is set to the
    number of targets per class.
    """
    arrays = {}
    for key, chunks in self.stats.items():
        arrays[key] = torch.cat(chunks, 0).cpu().numpy()  # to numpy
    if len(arrays) and arrays["tp"].any():
        self.metrics.process(**arrays)
    target_ids = arrays["target_cls"].astype(int)
    self.nt_per_class = np.bincount(target_ids, minlength=self.nc)  # number of targets per class
    return self.metrics.results_dict
|
||||||
|
|
||||||
|
def print_results(self):
    """Log the overall metrics row, optional per-class rows, and plot the confusion matrices."""
    row_format = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys)  # print format
    LOGGER.info(row_format % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
    if self.nt_per_class.sum() == 0:
        LOGGER.warning(f"WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels")

    # Print results per class
    per_class = self.args.verbose and not self.training and self.nc > 1 and len(self.stats)
    if per_class:
        for position, class_id in enumerate(self.metrics.ap_class_index):
            row = (self.names[class_id], self.seen, self.nt_per_class[class_id], *self.metrics.class_result(position))
            LOGGER.info(row_format % row)

    if not self.args.plots:
        return
    # One normalised and one raw confusion matrix.
    for normalize in (True, False):
        self.confusion_matrix.plot(
            save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot
        )
|
||||||
|
|
||||||
|
def _process_batch(self, detections, gt_bboxes, gt_cls):
    """
    Return the correct-prediction matrix for one image.

    Args:
        detections (torch.Tensor): Detections; columns 0-3 are read as the box and column 5
            as the class (documented upstream as x1, y1, x2, y2, conf, class).
        gt_bboxes (torch.Tensor): Ground-truth boxes compared against the detection boxes.
        gt_cls (torch.Tensor): Ground-truth classes, one per ground-truth box.

    Returns:
        The result of `self.match_predictions` (documented upstream as an [N, 10] boolean
        matrix, one column per IoU level).
    """
    pred_boxes = detections[:, :4]
    pred_classes = detections[:, 5]
    overlaps = box_iou(gt_bboxes, pred_boxes)
    return self.match_predictions(pred_classes, gt_cls, overlaps)
|
||||||
|
|
||||||
|
def build_dataset(self, img_path, mode="val", batch=None):
    """
    Build a YOLO dataset through `build_yolo_dataset`.

    Args:
        img_path (str): Path to the folder containing images.
        mode (str): `train` or `val`; forwarded so each mode can use its own augmentations.
        batch (int, optional): Batch size, used for `rect`. Defaults to None.

    Returns:
        The dataset object returned by `build_yolo_dataset`.
    """
    dataset = build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)
    return dataset
|
||||||
|
|
||||||
|
def get_dataloader(self, dataset_path, batch_size):
    """Build the validation dataset for `dataset_path` and wrap it in an unshuffled dataloader."""
    val_dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
    loader = build_dataloader(val_dataset, batch_size, self.args.workers, shuffle=False, rank=-1)
    return loader
|
||||||
|
|
||||||
|
def plot_val_samples(self, batch, ni):
    """Plot a validation batch with its ground-truth labels to `val_batch{ni}_labels.jpg`."""
    output_file = self.save_dir / f"val_batch{ni}_labels.jpg"
    class_ids = batch["cls"].squeeze(-1)
    plot_images(
        batch["img"],
        batch["batch_idx"],
        class_ids,
        batch["bboxes"],
        paths=batch["im_file"],
        fname=output_file,
        names=self.names,
        on_plot=self.on_plot,
    )
|
||||||
|
|
||||||
|
def plot_predictions(self, batch, preds, ni):
    """Plot predicted boxes over the batch images and save them as `val_batch{ni}_pred.jpg`."""
    targets = output_to_target(preds, max_det=self.args.max_det)
    output_file = self.save_dir / f"val_batch{ni}_pred.jpg"
    plot_images(
        batch["img"],
        *targets,
        paths=batch["im_file"],
        fname=output_file,
        names=self.names,
        on_plot=self.on_plot,
    )  # pred
|
||||||
|
|
||||||
|
def save_one_txt(self, predn, save_conf, shape, file):
    """
    Append detections to a txt file as normalized `cls x y w h [conf]` lines.

    Args:
        predn (torch.Tensor): Detections; each row is unpacked as four box values
            (converted with `ops.xyxy2xywh`), a confidence and a class.
        save_conf (bool): Whether to append the confidence to every line.
        shape: Image shape indexed as (height, width); used as the normalization gain.
        file: Path of the txt file, opened in append mode.

    The file is opened once for all rows rather than once per detection. No file is
    created when there are no detections, as before.
    """
    rows = predn.tolist()
    if not rows:
        return  # nothing to write; keep the "no file for empty predictions" behaviour
    gn = torch.tensor(shape)[[1, 0, 1, 0]]  # normalization gain whwh
    lines = []
    for *xyxy, conf, cls in rows:
        xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
        line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
        lines.append(("%g " * len(line)).rstrip() % line + "\n")
    with open(file, "a") as f:
        f.writelines(lines)
|
||||||
|
|
||||||
|
def pred_to_json(self, predn, filename):
    """
    Append one COCO-style JSON record per detection to `self.jdict`.

    The image id is the file stem, as an int when it is numeric. Boxes are converted to
    top-left x, y, width, height, and class indices are mapped through `self.class_map`.
    """
    stem = Path(filename).stem
    image_id = int(stem) if stem.isnumeric() else stem
    corner_boxes = ops.xyxy2xywh(predn[:, :4])  # xywh
    corner_boxes[:, :2] -= corner_boxes[:, 2:] / 2  # xy center to top-left corner
    for row, xywh in zip(predn.tolist(), corner_boxes.tolist()):
        record = {
            "image_id": image_id,
            "category_id": self.class_map[int(row[5])],
            "bbox": [round(value, 3) for value in xywh],
            "score": round(row[4], 5),
        }
        self.jdict.append(record)
|
||||||
|
|
||||||
|
def eval_json(self, stats):
    """
    Evaluate the saved JSON predictions with pycocotools and update `stats` in place.

    Runs only when `self.args.save_json` is set, the dataset was detected as COCO and at
    least one prediction was collected. Any failure (missing package, missing file,
    evaluation error) is logged as a warning and `stats` is returned unchanged.

    Args:
        stats (dict): Metrics dict; the entries keyed by the last two `self.metrics.keys`
            are overwritten with the first two pycocotools stats.

    Returns:
        (dict): The same `stats` object.
    """
    if not (self.args.save_json and self.is_coco and len(self.jdict)):
        return stats

    anno_json = self.data["path"] / "annotations/instances_val2017.json"  # annotations
    pred_json = self.save_dir / "predictions.json"  # predictions
    LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
    try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
        check_requirements("pycocotools>=2.0.6")
        from pycocotools.coco import COCO  # noqa
        from pycocotools.cocoeval import COCOeval  # noqa

        for x in anno_json, pred_json:
            # Raised explicitly (not asserted) so the check survives `python -O`;
            # the message matches the previous assertion text.
            if not x.is_file():
                raise FileNotFoundError(f"{x} file not found")
        anno = COCO(str(anno_json))  # init annotations api
        pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
        coco_eval = COCOeval(anno, pred, "bbox")  # named to avoid shadowing the builtin `eval`
        # is_coco is guaranteed by the guard at the top of the method.
        coco_eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # images to eval
        coco_eval.evaluate()
        coco_eval.accumulate()
        coco_eval.summarize()
        stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = coco_eval.stats[:2]  # update mAP50-95 and mAP50
    except Exception as e:
        LOGGER.warning(f"pycocotools unable to run: {e}")
    return stats
|
@ -3,5 +3,6 @@
|
|||||||
from .predict import SegmentationPredictor
|
from .predict import SegmentationPredictor
|
||||||
from .train import SegmentationTrainer
|
from .train import SegmentationTrainer
|
||||||
from .val import SegmentationValidator
|
from .val import SegmentationValidator
|
||||||
|
from .pgt_train import PGTSegmentationTrainer
|
||||||
|
|
||||||
__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator"
|
__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator", "PGTSegmentationTrainer"
|
||||||
|
74
ultralytics/models/yolo/segment/pgt_train.py
Normal file
74
ultralytics/models/yolo/segment/pgt_train.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
|
from copy import copy
|
||||||
|
|
||||||
|
from ultralytics.models import yolo
|
||||||
|
from ultralytics.nn.tasks import SegmentationModel, DetectionModel
|
||||||
|
from ultralytics.utils import DEFAULT_CFG, RANK
|
||||||
|
# from ultralytics.utils import yaml_load, IterableSimpleNamespace, ROOT
|
||||||
|
from ultralytics.utils.plotting import plot_images, plot_results
|
||||||
|
from ultralytics.models.yolov10.model import YOLOv10PGTDetectionModel
|
||||||
|
from ultralytics.models.yolov10.val import YOLOv10PGTDetectionValidator
|
||||||
|
|
||||||
|
# # Default configuration
|
||||||
|
# DEFAULT_CFG_DICT = yaml_load(ROOT / "cfg/pgt_train.yaml")
|
||||||
|
# for k, v in DEFAULT_CFG_DICT.items():
|
||||||
|
# if isinstance(v, str) and v.lower() == "none":
|
||||||
|
# DEFAULT_CFG_DICT[k] = None
|
||||||
|
# DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
|
||||||
|
# DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)
|
||||||
|
|
||||||
|
class PGTSegmentationTrainer(yolo.detect.PGTDetectionTrainer):
    """
    PGT trainer that forces the `segment` task while training a YOLOv10 PGT detection model.

    The task override is set to "segment", and training batches are plotted with their
    `masks`. The model and validator are `YOLOv10PGTDetectionModel` and
    `YOLOv10PGTDetectionValidator`.

    Example:
        ```python
        from ultralytics.models.yolo.segment import PGTSegmentationTrainer

        args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml', epochs=3)
        trainer = PGTSegmentationTrainer(overrides=args)
        trainer.train()
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """Initialise the trainer, always overriding the task with "segment"."""
        overrides = {} if overrides is None else overrides
        overrides["task"] = "segment"
        super().__init__(cfg, overrides, _callbacks)

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Return a YOLOv10PGTDetectionModel built from `cfg`, loading `weights` when given."""
        pgt_model = YOLOv10PGTDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if not weights:
            return pgt_model
        pgt_model.load(weights)
        return pgt_model

    def get_validator(self):
        """Set the seven loss names (one2many, one2one, pgt) and return a YOLOv10PGTDetectionValidator."""
        self.loss_names = ("box_om", "cls_om", "dfl_om", "box_oo", "cls_oo", "dfl_oo", "pgt_loss")
        validator_args = copy(self.args)
        return YOLOv10PGTDetectionValidator(
            self.test_loader, save_dir=self.save_dir, args=validator_args, _callbacks=self.callbacks
        )

    def plot_training_samples(self, batch, ni):
        """Plot a training batch with labels, boxes and masks to `train_batch{ni}.jpg`."""
        output_file = self.save_dir / f"train_batch{ni}.jpg"
        plot_images(
            batch["img"],
            batch["batch_idx"],
            batch["cls"].squeeze(-1),
            batch["bboxes"],
            masks=batch["masks"],
            paths=batch["im_file"],
            fname=output_file,
            on_plot=self.on_plot,
        )

    def plot_metrics(self):
        """Plot training/val metrics from the results CSV with `segment=True`."""
        plot_results(file=self.csv, segment=True, on_plot=self.on_plot)  # save results.png
|
@ -1,5 +1,5 @@
|
|||||||
from .model import YOLOv10
|
from .model import YOLOv10, YOLOv10PGT
|
||||||
from .predict import YOLOv10DetectionPredictor
|
from .predict import YOLOv10DetectionPredictor
|
||||||
from .val import YOLOv10DetectionValidator
|
from .val import YOLOv10DetectionValidator
|
||||||
|
|
||||||
__all__ = "YOLOv10DetectionPredictor", "YOLOv10DetectionValidator", "YOLOv10"
|
__all__ = "YOLOv10DetectionPredictor", "YOLOv10DetectionValidator", "YOLOv10", "YOLOv10PGT"
|
||||||
|
@ -1,8 +1,11 @@
|
|||||||
from ultralytics.engine.model import Model
|
from ultralytics.engine.model import Model
|
||||||
from ultralytics.nn.tasks import YOLOv10DetectionModel
|
from ultralytics.nn.tasks import YOLOv10DetectionModel, YOLOv10PGTDetectionModel
|
||||||
from .val import YOLOv10DetectionValidator
|
from .val import YOLOv10DetectionValidator
|
||||||
from .predict import YOLOv10DetectionPredictor
|
from .predict import YOLOv10DetectionPredictor
|
||||||
from .train import YOLOv10DetectionTrainer
|
from .train import YOLOv10DetectionTrainer
|
||||||
|
from .pgt_train import YOLOv10PGTDetectionTrainer
|
||||||
|
# from ..yolo.segment import PGTSegmentationTrainer
|
||||||
|
# from .pgt_trainer import YOLOv10DetectionTrainer
|
||||||
|
|
||||||
from huggingface_hub import PyTorchModelHubMixin
|
from huggingface_hub import PyTorchModelHubMixin
|
||||||
from .card import card_template_text
|
from .card import card_template_text
|
||||||
@ -30,6 +33,39 @@ class YOLOv10(Model, PyTorchModelHubMixin, model_card_template=card_template_tex
|
|||||||
"detect": {
|
"detect": {
|
||||||
"model": YOLOv10DetectionModel,
|
"model": YOLOv10DetectionModel,
|
||||||
"trainer": YOLOv10DetectionTrainer,
|
"trainer": YOLOv10DetectionTrainer,
|
||||||
|
"pgt_trainer": YOLOv10PGTDetectionTrainer,
|
||||||
|
"validator": YOLOv10DetectionValidator,
|
||||||
|
"predictor": YOLOv10DetectionPredictor,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def _get_pgt_segmentation_trainer():
    """Import and return PGTSegmentationTrainer at call time instead of at module import."""
    # NOTE(review): presumably deferred to avoid a circular import with ..yolo.segment - confirm.
    from ..yolo.segment import PGTSegmentationTrainer
    return PGTSegmentationTrainer
|
||||||
|
|
||||||
|
class YOLOv10PGT(Model, PyTorchModelHubMixin, model_card_template=card_template_text):
|
||||||
|
|
||||||
|
def __init__(self, model="yolov10n.pt", task=None, verbose=False, names=None):
    """
    Initialise the wrapper and optionally override the class names on the loaded model.

    Args:
        model: Model identifier forwarded to the parent initialiser.
        task: Task forwarded to the parent initialiser.
        verbose: Verbosity flag forwarded to the parent initialiser.
        names: When not None, assigned to `self.model.names`.
    """
    super().__init__(model=model, task=task, verbose=verbose)
    if names is None:
        return
    self.model.names = names
|
||||||
|
|
||||||
|
def push_to_hub(self, repo_name, **kwargs):
    """
    Push the model to the Hub, recording names, model yaml file and task in the config.

    Args:
        repo_name: Repository identifier forwarded to the parent `push_to_hub`.
        **kwargs: Forwarded to the parent `push_to_hub`. An optional `config` mapping is
            copied and extended with the `names`, `model` and `task` entries, so the
            caller's dict is not modified.

    Returns:
        Whatever the parent `push_to_hub` returns (previously discarded).
    """
    # Copy so the caller's dict is not mutated; also tolerates an explicit config=None.
    config = dict(kwargs.get('config') or {})
    config['names'] = self.names
    config['model'] = self.model.yaml['yaml_file']
    config['task'] = self.task
    kwargs['config'] = config
    return super().push_to_hub(repo_name, **kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def task_map(self):
|
||||||
|
"""Map head to model, trainer, validator, and predictor classes."""
|
||||||
|
return {
|
||||||
|
"detect": {
|
||||||
|
"model": YOLOv10DetectionModel,
|
||||||
|
"trainer": _get_pgt_segmentation_trainer(),
|
||||||
"validator": YOLOv10DetectionValidator,
|
"validator": YOLOv10DetectionValidator,
|
||||||
"predictor": YOLOv10DetectionPredictor,
|
"predictor": YOLOv10DetectionPredictor,
|
||||||
},
|
},
|
||||||
|
21
ultralytics/models/yolov10/pgt_train.py
Normal file
21
ultralytics/models/yolov10/pgt_train.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from ultralytics.models.yolo.detect import DetectionTrainer
|
||||||
|
from ultralytics.models.yolo.detect import PGTDetectionTrainer
|
||||||
|
from .val import YOLOv10DetectionValidator
|
||||||
|
from .model import YOLOv10DetectionModel
|
||||||
|
from copy import copy
|
||||||
|
from ultralytics.utils import RANK
|
||||||
|
|
||||||
|
class YOLOv10PGTDetectionTrainer(PGTDetectionTrainer):
    """PGT detection trainer that builds YOLOv10 detection models and validators."""

    def get_validator(self):
        """Set the one2many/one2one loss names and return a YOLOv10DetectionValidator."""
        self.loss_names = ("box_om", "cls_om", "dfl_om", "box_oo", "cls_oo", "dfl_oo")
        validator_args = copy(self.args)
        return YOLOv10DetectionValidator(
            self.test_loader, save_dir=self.save_dir, args=validator_args, _callbacks=self.callbacks
        )

    def get_model(self, cfg=None, weights=None, verbose=True):
        """Return a YOLOv10DetectionModel built from `cfg`, loading `weights` when given."""
        detector = YOLOv10DetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if not weights:
            return detector
        detector.load(weights)
        return detector
|
@ -1,4 +1,4 @@
|
|||||||
from ultralytics.models.yolo.detect import DetectionValidator
|
from ultralytics.models.yolo.detect import DetectionValidator, PGTDetectionValidator
|
||||||
from ultralytics.utils import ops
|
from ultralytics.utils import ops
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -22,3 +22,24 @@ class YOLOv10DetectionValidator(DetectionValidator):
|
|||||||
boxes, scores, labels = ops.v10postprocess(preds, self.args.max_det, self.nc)
|
boxes, scores, labels = ops.v10postprocess(preds, self.args.max_det, self.nc)
|
||||||
bboxes = ops.xywh2xyxy(boxes)
|
bboxes = ops.xywh2xyxy(boxes)
|
||||||
return torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
|
return torch.cat([bboxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
|
||||||
|
|
||||||
|
class YOLOv10PGTDetectionValidator(PGTDetectionValidator):
    """PGT detection validator that post-processes YOLOv10 outputs through `ops.v10postprocess`."""

    def __init__(self, *args, **kwargs):
        """Initialise the parent validator, then OR `is_coco` into `args.save_json`."""
        super().__init__(*args, **kwargs)
        self.args.save_json |= self.is_coco

    def postprocess(self, preds):
        """
        Convert raw model output into rows of `[x1, y1, x2, y2, score, label]`.

        A dict input is reduced to its "one2one" entry and a list/tuple to its first
        element. A tensor whose last dimension is already 6 is returned unchanged.
        """
        if isinstance(preds, dict):
            preds = preds["one2one"]

        if isinstance(preds, (list, tuple)):
            preds = preds[0]

        # Acknowledgement: Thanks to sanha9999 in #190 and #181!
        if preds.shape[-1] == 6:
            return preds

        transposed = preds.transpose(-1, -2)
        boxes, scores, labels = ops.v10postprocess(transposed, self.args.max_det, self.nc)
        corner_boxes = ops.xywh2xyxy(boxes)
        return torch.cat([corner_boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
|
@ -57,7 +57,7 @@ from ultralytics.nn.modules import (
|
|||||||
)
|
)
|
||||||
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
|
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
|
||||||
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
|
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
|
||||||
from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss, v10DetectLoss
|
from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss, v10DetectLoss, v10PGTDetectLoss
|
||||||
from ultralytics.utils.plotting import feature_visualization
|
from ultralytics.utils.plotting import feature_visualization
|
||||||
from ultralytics.utils.torch_utils import (
|
from ultralytics.utils.torch_utils import (
|
||||||
fuse_conv_and_bn,
|
fuse_conv_and_bn,
|
||||||
@ -93,7 +93,7 @@ class BaseModel(nn.Module):
|
|||||||
return self.loss(x, *args, **kwargs)
|
return self.loss(x, *args, **kwargs)
|
||||||
return self.predict(x, *args, **kwargs)
|
return self.predict(x, *args, **kwargs)
|
||||||
|
|
||||||
def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
|
def predict(self, x, profile=False, visualize=False, augment=False, embed=None, return_images=False):
|
||||||
"""
|
"""
|
||||||
Perform a forward pass through the network.
|
Perform a forward pass through the network.
|
||||||
|
|
||||||
@ -107,9 +107,12 @@ class BaseModel(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
(torch.Tensor): The last output of the model.
|
(torch.Tensor): The last output of the model.
|
||||||
"""
|
"""
|
||||||
|
if return_images:
|
||||||
|
x = x.requires_grad_(True)
|
||||||
if augment:
|
if augment:
|
||||||
return self._predict_augment(x)
|
return self._predict_augment(x)
|
||||||
return self._predict_once(x, profile, visualize, embed)
|
out = self._predict_once(x, profile, visualize, embed)
|
||||||
|
return (out, x) if return_images else out
|
||||||
|
|
||||||
def _predict_once(self, x, profile=False, visualize=False, embed=None):
|
def _predict_once(self, x, profile=False, visualize=False, embed=None):
|
||||||
"""
|
"""
|
||||||
@ -140,13 +143,13 @@ class BaseModel(nn.Module):
|
|||||||
return torch.unbind(torch.cat(embeddings, 1), dim=0)
|
return torch.unbind(torch.cat(embeddings, 1), dim=0)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def _predict_augment(self, x):
|
def _predict_augment(self, x, *args, **kwargs):
|
||||||
"""Perform augmentations on input image x and return augmented inference."""
|
"""Perform augmentations on input image x and return augmented inference."""
|
||||||
LOGGER.warning(
|
LOGGER.warning(
|
||||||
f"WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. "
|
f"WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. "
|
||||||
f"Reverting to single-scale inference instead."
|
f"Reverting to single-scale inference instead."
|
||||||
)
|
)
|
||||||
return self._predict_once(x)
|
return self._predict_once(x, *args, **kwargs)
|
||||||
|
|
||||||
def _profile_one_layer(self, m, x, dt):
|
def _profile_one_layer(self, m, x, dt):
|
||||||
"""
|
"""
|
||||||
@ -260,7 +263,7 @@ class BaseModel(nn.Module):
|
|||||||
if verbose:
|
if verbose:
|
||||||
LOGGER.info(f"Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights")
|
LOGGER.info(f"Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights")
|
||||||
|
|
||||||
def loss(self, batch, preds=None):
|
def loss(self, batch, preds=None, return_images=False):
|
||||||
"""
|
"""
|
||||||
Compute loss.
|
Compute loss.
|
||||||
|
|
||||||
@ -271,8 +274,12 @@ class BaseModel(nn.Module):
|
|||||||
if not hasattr(self, "criterion"):
|
if not hasattr(self, "criterion"):
|
||||||
self.criterion = self.init_criterion()
|
self.criterion = self.init_criterion()
|
||||||
|
|
||||||
preds = self.forward(batch["img"]) if preds is None else preds
|
preds = self.forward(batch["img"], return_images=return_images) if preds is None else preds
|
||||||
return self.criterion(preds, batch)
|
if return_images:
|
||||||
|
preds, im = preds
|
||||||
|
loss = self.criterion(preds, batch)
|
||||||
|
out = loss if not return_images else (loss, im)
|
||||||
|
return out
|
||||||
|
|
||||||
def init_criterion(self):
|
def init_criterion(self):
|
||||||
"""Initialize the loss criterion for the BaseModel."""
|
"""Initialize the loss criterion for the BaseModel."""
|
||||||
@ -645,6 +652,10 @@ class YOLOv10DetectionModel(DetectionModel):
|
|||||||
def init_criterion(self):
|
def init_criterion(self):
|
||||||
return v10DetectLoss(self)
|
return v10DetectLoss(self)
|
||||||
|
|
||||||
|
class YOLOv10PGTDetectionModel(DetectionModel):
    """Detection model whose training criterion is `v10PGTDetectLoss`."""

    def init_criterion(self):
        """Build the PGT loss, passing `self.args.pgt_coeff` when present and None otherwise."""
        coeff = getattr(self.args, 'pgt_coeff', None)
        return v10PGTDetectLoss(self, pgt_coeff=coeff)
|
||||||
|
|
||||||
class Ensemble(nn.ModuleList):
|
class Ensemble(nn.ModuleList):
|
||||||
"""Ensemble of models."""
|
"""Ensemble of models."""
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
|
|||||||
from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
|
from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
|
||||||
from .metrics import bbox_iou, probiou
|
from .metrics import bbox_iou, probiou
|
||||||
from .tal import bbox2dist
|
from .tal import bbox2dist
|
||||||
|
from ultralytics.utils.plaus_functs import get_dist_reg, plaus_loss_fn
|
||||||
|
|
||||||
class VarifocalLoss(nn.Module):
|
class VarifocalLoss(nn.Module):
|
||||||
"""
|
"""
|
||||||
@ -725,3 +725,46 @@ class v10DetectLoss:
|
|||||||
one2one = preds["one2one"]
|
one2one = preds["one2one"]
|
||||||
loss_one2one = self.one2one(one2one, batch)
|
loss_one2one = self.one2one(one2one, batch)
|
||||||
return loss_one2many[0] + loss_one2one[0], torch.cat((loss_one2many[1], loss_one2one[1]))
|
return loss_one2many[0] + loss_one2one[0], torch.cat((loss_one2many[1], loss_one2one[1]))
|
||||||
|
|
||||||
|
class v10PGTDetectLoss:
    """
    YOLOv10 detection loss (one2many + one2one) with an optional input-gradient term.

    When `return_plaus` is set, the gradient of the detection loss with respect to the
    input images is squared and combined with a mask-derived term by `plaus_loss_fn`,
    weighted by `pgt_coeff`. The result is added to the total loss.
    """

    def __init__(self, model, pgt_coeff):
        """
        Create both detection losses and store the PGT coefficient.

        Args:
            model: Model forwarded to both `v8DetectionLoss` instances.
            pgt_coeff: Weight passed to `plaus_loss_fn`; falls back to 2.0 when None.
        """
        self.one2many = v8DetectionLoss(model, tal_topk=10)
        self.one2one = v8DetectionLoss(model, tal_topk=1)
        self.pgt_coeff = pgt_coeff if pgt_coeff is not None else 2.0

    def __call__(self, preds, batch, return_plaus=True, pgt_coeff=None):
        """
        Compute the combined loss.

        Args:
            preds (dict): Predictions with "one2many" and "one2one" entries.
            batch (dict): Training batch; `batch['img']` is marked as requiring grad, and
                `batch['masks']` is read when `return_plaus` is True.
            return_plaus (bool): Whether to add the input-gradient term.
            pgt_coeff: When not None, replaces the stored coefficient for this and all
                later calls.

        Returns:
            (tuple): The total loss and the concatenated loss items; the items end with the
            extra term when `return_plaus` is True.
        """
        if pgt_coeff is not None:
            self.pgt_coeff = pgt_coeff
        batch['img'] = batch['img'].requires_grad_(True)

        loss_one2many = self.one2many(preds["one2many"], batch)
        loss_one2one = self.one2one(preds["one2one"], batch)
        loss = loss_one2many[0] + loss_one2one[0]

        if not return_plaus:
            return loss, torch.cat((loss_one2many[1], loss_one2one[1]))

        smask = get_dist_reg(batch['img'], batch['masks'])

        try:
            # create_graph=True keeps the gradient differentiable so the added term can be trained on.
            grad = torch.autograd.grad(loss, batch['img'], retain_graph=True, create_graph=True)[0]
        except RuntimeError:
            # Autograd failures are RuntimeError (NotImplementedError is a subclass). Retry
            # without a second-order graph, and let KeyboardInterrupt/SystemExit propagate.
            grad = torch.autograd.grad(loss, batch['img'], retain_graph=True, create_graph=False)[0]

        grad = grad ** 2

        plaus_loss = plaus_loss_fn(grad, smask, self.pgt_coeff)
        loss += plaus_loss

        return loss, torch.cat((loss_one2many[1], loss_one2one[1], plaus_loss.unsqueeze(0)))
|
||||||
|
|
844
ultralytics/utils/plaus_functs.py
Normal file
844
ultralytics/utils/plaus_functs.py
Normal file
@ -0,0 +1,844 @@
|
|||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
# from plot_functs import *
|
||||||
|
from .plot_functs import normalize_tensor, overlay_mask, imshow
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
import matplotlib.path as mplPath
|
||||||
|
from matplotlib.path import Path
|
||||||
|
# from utils.general import non_max_suppression, xyxy2xywh, scale_coords
|
||||||
|
from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh, non_max_suppression
|
||||||
|
from .metrics import bbox_iou
|
||||||
|
import torchvision.transforms as T
|
||||||
|
|
||||||
|
def plaus_loss_fn(grad, smask, pgt_coeff, square=True):
    """
    Compute the plausibility (PGT) loss from an attribution map and a soft mask.

    Args:
        grad (torch.Tensor): Attribution map (the caller passes squared input gradients).
        smask (torch.Tensor): Weight map multiplied element-wise with ``grad``.
        pgt_coeff (float): Scale applied to the final loss.
        square (bool, optional): Square ``1 - plaus_reg`` before scaling. Defaults to True.

    Returns:
        torch.Tensor: Scalar loss ``(1 - plaus_reg) ** 2 * pgt_coeff`` (unsquared when
        ``square`` is False), with ``plaus_reg = (1 + dist_reg) / 2``.
    """
    # NOTE(review): the "positive" term is weighted by (1 - smask) and the
    # "negative" term by smask — confirm this matches the polarity of the mask
    # produced by get_dist_reg.
    attr_pos = attr_reg(grad, (1.0 - smask))
    attr_neg = attr_reg(grad, smask)
    # Both terms are expressed relative to the mean attribution.
    dist_reg = ((attr_pos / torch.mean(grad)) - (attr_neg / torch.mean(grad)))
    plaus_reg = (1.0 + dist_reg) / 2.0
    shortfall = 1 - plaus_reg
    if square:
        shortfall = shortfall ** 2
    return shortfall * pgt_coeff
|
||||||
|
|
||||||
|
def get_dist_reg(images, seg_mask):
    """
    Turn a segmentation mask into a soft, multi-scale blurred weight map.

    The mask is resized to the spatial size of ``images``, binarised (values > 0 become 1),
    repeated to 3 channels, blurred with eight Gaussian kernels (sigma 20, 40, ..., 160),
    and the element-wise maximum over all blurs is normalised with ``normalize_tensor``.

    Args:
        images (torch.Tensor): Reference tensor; dims 2 and 3 give the target size.
        seg_mask (torch.Tensor): Mask tensor; a channel axis is inserted at dim 1.

    Returns:
        torch.Tensor: Normalised soft mask on ``images.device``.
    """
    target_size = (images.shape[2], images.shape[3])
    seg_mask = T.Resize(target_size, antialias=True)(seg_mask).to(images.device)
    seg_mask = seg_mask.to(dtype=torch.float32).unsqueeze(1).repeat(1, 3, 1, 1)
    seg_mask[seg_mask > 0] = 1.0

    smask = torch.zeros_like(seg_mask)
    for step in range(8):
        sigma = 20.0 + (step * 20.0)
        # GaussianBlur needs an odd kernel size.
        kernel_size = int(sigma + 50)
        if kernel_size % 2 == 0:
            kernel_size += 1
        blurred = T.GaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=sigma)(seg_mask)
        if torch.max(blurred) > 1.0:
            blurred = normalize_tensor(blurred)
        smask = torch.max(smask, blurred)

    return normalize_tensor(smask)
|
||||||
|
|
||||||
|
def get_gradient(img, grad_wrt, norm=False, absolute=True, grayscale=False, keepmean=False):
    """
    Compute an attribution map as the gradient of ``grad_wrt`` with respect to ``img``.

    Args:
        img (torch.Tensor): Tensor the gradient is taken with respect to.
        grad_wrt (torch.Tensor): Tensor being differentiated. Unless its shape is ``[]`` or
            ``[1]``, a tensor of ones is used as ``grad_outputs``.
        norm (bool, optional): Normalise the map with ``normalize_batch``. Defaults to False.
        absolute (bool, optional): Take absolute values of the gradient. Defaults to True.
        grayscale (bool, optional): Sum over dim 1, keeping the dim. Defaults to False.
        keepmean (bool, optional): With ``norm``, re-centre the normalised map around
            ``mean / (max - min)`` of the un-normalised map. Defaults to False.

    Returns:
        torch.Tensor: The attribution map (created with ``create_graph=True``).
    """
    is_scalar = (grad_wrt.shape == torch.Size([1])) or (grad_wrt.shape == torch.Size([]))
    seed = None if is_scalar else torch.ones_like(grad_wrt).clone().detach()
    # create_graph=True allows higher-order derivatives at a significant compute cost.
    attribution_map = torch.autograd.grad(grad_wrt, img,
                                          grad_outputs=seed,
                                          create_graph=True,
                                          )[0]
    if absolute:
        attribution_map = torch.abs(attribution_map)
    if grayscale:
        attribution_map = torch.sum(attribution_map, 1, keepdim=True)
    if not norm:
        return attribution_map

    if keepmean:
        # Statistics of the un-normalised map, captured before normalising.
        attmean = torch.mean(attribution_map)
        attmin = torch.min(attribution_map)
        attmax = torch.max(attribution_map)
    attribution_map = normalize_batch(attribution_map)
    if keepmean:
        attribution_map -= attribution_map.mean()
        attribution_map += (attmean / (attmax - attmin))

    return attribution_map
|
||||||
|
|
||||||
|
def get_gaussian(img, grad_wrt, norm=True, absolute=True, grayscale=True, keepmean=False):
    """
    Draw standard-normal noise shaped like ``img`` and post-process it like an attribution map.

    Args:
        img (torch.Tensor): Tensor whose shape, dtype and device the noise copies.
        grad_wrt: Unused; present so the signature mirrors ``get_gradient``.
        norm (bool, optional): Normalise with ``normalize_batch``. Defaults to True.
        absolute (bool, optional): Take absolute values of the noise. Defaults to True.
        grayscale (bool, optional): Sum over dim 1, keeping the dim. Defaults to True.
        keepmean (bool, optional): With ``norm``, re-centre the normalised noise around
            ``mean / (max - min)`` of the un-normalised noise. Defaults to False.

    Returns:
        torch.Tensor: The processed noise tensor.
    """
    noise = torch.randn_like(img)

    if absolute:
        noise = torch.abs(noise)
    if grayscale:
        noise = torch.sum(noise, 1, keepdim=True)
    if not norm:
        return noise

    if keepmean:
        # Statistics of the un-normalised noise, captured before normalising.
        attmean = torch.mean(noise)
        attmin = torch.min(noise)
        attmax = torch.max(noise)
    noise = normalize_batch(noise)
    if keepmean:
        noise -= noise.mean()
        noise += (attmean / (attmax - attmin))

    return noise
|
||||||
|
|
||||||
|
|
||||||
|
def get_plaus_score(targets_out, attr, debug=False, corners=False, imgs=None, eps = 1e-7):
    # TODO: Remove imgs from this function and only take it as input if debug is True
    """
    Compute the fraction of total attribution that falls inside the target boxes.

    Args:
        targets_out (torch.Tensor): Target rows passed on to ``get_bbox_map``.
        attr (torch.Tensor): Attribution maps; NaNs are replaced by 0 before scoring.
        debug (bool, optional): Save per-image debug figures through ``imshow``.
            Defaults to False.
        corners (bool, optional): Forwarded to ``get_bbox_map``. Defaults to False.
        imgs (torch.Tensor, optional): Images used only for debug figures; a zero tensor
            shaped like ``attr`` is substituted when omitted.
        eps (float, optional): Currently unused.

    Returns:
        torch.Tensor: ``sum(attr * bbox_map) / sum(attr)``; NaN when ``attr`` sums to zero.
    """
    if torch.isnan(attr).any():
        attr = torch.nan_to_num(attr, nan=0.0)

    coords_map = get_bbox_map(targets_out, attr, corners=corners)
    plaus_score = ((torch.sum((attr * coords_map))) / (torch.sum(attr)))

    if debug:
        # Resolve the fallback before the loop so imgs[i] is always valid.
        if imgs is None:
            imgs = torch.zeros_like(attr)
        for i in range(len(coords_map)):
            # get_bbox_map returns floats; masked indexing needs a boolean tensor.
            coords_map3ch = torch.cat([coords_map[i][:1], coords_map[i][:1], coords_map[i][:1]], dim=0).to(torch.bool)
            test_bbox = torch.zeros_like(imgs[i])
            test_bbox[coords_map3ch] = imgs[i][coords_map3ch]
            imshow(test_bbox, save_path='figs/test_bbox')
            imshow(imgs[i], save_path='figs/im0')
            imshow(attr[i], save_path='figs/attr')

    return plaus_score
|
||||||
|
|
||||||
|
def get_attr_corners(targets_out, attr, debug=False, corners=False, imgs=None, eps = 1e-7):
    # TODO: Remove imgs from this function and only take it as input if debug is True
    """
    Score how well the attribution maxima around each target match the target box.

    For every target row the channel-summed attribution map of its image is split into
    four quadrants at a point derived from columns 2 and 3 of the row. The argmax
    position in each quadrant is combined into one 4-value box, and the mean of
    ``bbox_iou`` between those boxes and the target boxes is returned.

    Args:
        targets_out (torch.Tensor): Target rows; column 0 indexes the image in the batch
            and columns 2:6 hold the box (assumed normalised xywh — TODO confirm).
        attr (torch.Tensor): Attribution maps indexed as (batch, channel, dim2, dim3).
        debug (bool, optional): Save debug figures through ``imshow``; requires ``imgs``.
        corners (bool, optional): Use columns 2:6 directly as integer corners for the
            debug mask instead of converting them with ``corners_coords_batch``.
        imgs (torch.Tensor, optional): Images, only read when ``debug`` is true.
        eps (float, optional): Currently unused.

    Returns:
        torch.Tensor: Mean of the values returned by ``bbox_iou``.
    """
    target_inds = targets_out[:, 0].int()
    xyxy_batch = targets_out[:, 2:6]# * pre_gen_gains[out_num]
    # Per-row scale built from attr.shape[2] (columns 0 and 2) and attr.shape[3] (columns 1 and 3).
    num_pixels = torch.tile(torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device), (xyxy_batch.shape[0], 1))
    xyxy_corners = (corners_coords_batch(xyxy_batch) * num_pixels).int()
    co = xyxy_corners
    if corners:
        co = targets_out[:, 2:6].int()
    # Boolean mask of the target boxes; only consumed by the debug branch below.
    coords_map = torch.zeros_like(attr, dtype=torch.bool)
    # Columns 1/3 slice dim 2 and columns 0/2 slice dim 3 of the map.
    x1, x2 = co[:,1], co[:,3]
    y1, y2 = co[:,0], co[:,2]

    for ic in range(co.shape[0]): # potential for speedup here with torch indexing instead of for loop
        coords_map[target_inds[ic], :,x1[ic]:x2[ic],y1[ic]:y2[ic]] = True

    if torch.isnan(attr).any():
        attr = torch.nan_to_num(attr, nan=0.0)
    if debug:
        for i in range(len(coords_map)):
            coords_map3ch = torch.cat([coords_map[i][:1], coords_map[i][:1], coords_map[i][:1]], dim=0)
            test_bbox = torch.zeros_like(imgs[i])
            test_bbox[coords_map3ch] = imgs[i][coords_map3ch]
            imshow(test_bbox, save_path='figs/test_bbox')
            imshow(imgs[i], save_path='figs/im0')
            imshow(attr[i], save_path='figs/attr')

    # Split point per target: first two box columns scaled to pixels, plus one.
    co = (xyxy_batch * num_pixels).int()
    x1 = co[:,1] + 1
    y1 = co[:,0] + 1
    attr_ = torch.sum(attr, 1, keepdim=True)
    corners_attr = None #torch.zeros(len(xyxy_batch), 4, device=attr.device)
    for ic in range(co.shape[0]):
        # Four quadrants of the summed map around (x1, y1).
        attr0 = attr_[target_inds[ic], :,:x1[ic],:y1[ic]]
        attr1 = attr_[target_inds[ic], :,:x1[ic],y1[ic]:]
        attr2 = attr_[target_inds[ic], :,x1[ic]:,:y1[ic]]
        attr3 = attr_[target_inds[ic], :,x1[ic]:,y1[ic]:]

        x_0, y_0 = max_indices_2d(attr0[0])
        x_1, y_1 = max_indices_2d(attr1[0])
        x_2, y_2 = max_indices_2d(attr2[0])
        x_3, y_3 = max_indices_2d(attr3[0])

        # Shift quadrant-local indices back into full-map coordinates.
        y_1 += y1[ic]
        x_2 += x1[ic]
        x_3 += x1[ic]
        y_3 += y1[ic]

        # One 4-value box per target, divided by the map size.
        max_corners = torch.cat([torch.min(x_0, x_2).unsqueeze(0) / attr_.shape[2],
                                 torch.min(y_0, y_1).unsqueeze(0) / attr_.shape[3],
                                 torch.max(x_1, x_3).unsqueeze(0) / attr_.shape[2],
                                 torch.max(y_2, y_3).unsqueeze(0) / attr_.shape[3]])
        if corners_attr is None:
            corners_attr = max_corners
        else:
            corners_attr = torch.cat([corners_attr, max_corners], dim=0)
    corners_attr = corners_attr.view(-1, 4)
    # NOTE(review): confirm the imported bbox_iou accepts the `x1y1x2y2` and `metric`
    # keywords and a transposed (4, N) first argument.
    IoU_ = bbox_iou(corners_attr.T, xyxy_batch, x1y1x2y2=False, metric='CIoU')
    plaus_score = IoU_.mean()

    return plaus_score
|
||||||
|
|
||||||
|
def max_indices_2d(x_inp):
    """
    Locate the largest element of a 2-D tensor.

    Args:
        x_inp (torch.Tensor): 2-D tensor to search.

    Returns:
        torch.Tensor: 1-D tensor ``[row, column]`` of the maximum (first occurrence in
        flattened order, as reported by ``torch.argmax``).
    """
    # The flat argmax is split into row/column using the row length.
    n_cols = x_inp.shape[1]
    index = torch.argmax(x_inp)
    x = index // n_cols
    y = index % n_cols

    return torch.cat([x.unsqueeze(0), y.unsqueeze(0)])
|
||||||
|
|
||||||
|
|
||||||
|
def point_in_polygon(poly, grid):
    """
    Test which grid points lie inside a polygon using an even-odd crossing rule.

    Args:
        poly (torch.Tensor): Polygon vertices, shape (num_points, 2) as (x, y).
        grid (torch.Tensor): Points with (x, y) in the last dimension.

    Returns:
        torch.Tensor: Boolean tensor shaped like ``grid[..., 0]``; True for inside points.
    """
    grid_x, grid_y = grid[..., 0], grid[..., 1]
    inside = torch.zeros_like(grid_x, dtype=torch.bool)
    prev = poly.shape[0] - 1
    for cur in range(poly.shape[0]):
        cur_x, cur_y = poly[cur, 0], poly[cur, 1]
        prev_x, prev_y = poly[prev, 0], poly[prev, 1]
        # The edge (prev -> cur) straddles the point's y value ...
        straddles = ((cur_y < grid_y) & (prev_y >= grid_y)) | ((prev_y < grid_y) & (cur_y >= grid_y))
        # ... and the point lies to the left of the edge at that y value.
        left_of_edge = (grid_x - cur_x) < (prev_x - cur_x) * (grid_y - cur_y) / (prev_y - cur_y)
        inside = inside ^ (straddles & left_of_edge)
        prev = cur
    return inside
|
||||||
|
|
||||||
|
def point_in_polygon_gpu(poly, grid):
    """
    Batched variant of the even-odd point-in-polygon test.

    All polygon edges are evaluated at once and the per-edge crossing flags are
    folded together with XOR.

    Args:
        poly (torch.Tensor): Polygon vertices, shape (num_points, 2) as (x, y).
        grid (torch.Tensor): Points with (x, y) in the last dimension. The vertex
            tensor is expanded to ``grid.shape[0]`` along both spatial axes.

    Returns:
        torch.Tensor: Boolean crossing parity with a leading axis of length 1.
    """
    n_vertices = poly.shape[0]
    cur = torch.arange(n_vertices)
    prev = (cur - 1) % n_vertices
    side = grid.shape[0]
    poly_expanded = poly.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, side, side)

    grid_x, grid_y = grid[..., 0], grid[..., 1]
    cur_x, cur_y = poly_expanded[cur, 0], poly_expanded[cur, 1]
    prev_x, prev_y = poly_expanded[prev, 0], poly_expanded[prev, 1]

    straddles = ((cur_y < grid_y) & (prev_y >= grid_y)) | ((prev_y < grid_y) & (cur_y >= grid_y))
    left_of_edge = (grid_x - cur_x) < (prev_x - cur_x) * (grid_y - cur_y) / (prev_y - cur_y)
    cond = straddles & left_of_edge

    # XOR-reduce along the edge axis by repeatedly folding the two halves together;
    # an odd leftover slice is set aside and folded in at the end.
    leftovers = []
    while len(cond) > 1:
        if len(cond) % 2 == 1:
            leftovers.append(cond[-1])
            cond = cond[:-1]
        half = len(cond) // 2
        cond = torch.bitwise_xor(cond[:half], cond[half:])
    for extra in leftovers:
        cond = torch.bitwise_xor(cond, extra)
    return cond
|
||||||
|
|
||||||
|
|
||||||
|
def bitmap_for_polygon(poly, h, w):
    """
    Rasterise a polygon into a boolean bitmap of size (h, w).

    Args:
        poly (torch.Tensor): Polygon vertices as (x, y) rows.
        h (int): Number of rows in the bitmap.
        w (int): Number of columns in the bitmap.

    Returns:
        torch.Tensor: Result of ``point_in_polygon`` on the pixel grid, with a leading
        axis of length 1.
    """
    row_coords = torch.arange(h).to(poly.device).float()
    col_coords = torch.arange(w).to(poly.device).float()
    grid_y, grid_x = torch.meshgrid(row_coords, col_coords)
    # Points are stored as (x, y) in the last dimension.
    pixel_grid = torch.stack((grid_x, grid_y), dim=-1)
    return point_in_polygon(poly, pixel_grid).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
def corners_coords(center_xywh):
    """
    Convert one centre-format box ``(cx, cy, w, h)`` to corner format.

    Args:
        center_xywh: Four values ``(center_x, center_y, w, h)``.

    Returns:
        torch.Tensor: ``[x, y, x + w, y + h]`` with ``x = cx - w/2`` and ``y = cy - h/2``.
    """
    cx, cy, box_w, box_h = center_xywh
    left = cx - box_w/2
    top = cy - box_h/2
    return torch.tensor([left, top, left+box_w, top+box_h])
|
||||||
|
|
||||||
|
def corners_coords_batch(center_xywh):
    """
    Convert a batch of centre-format boxes to corner format.

    Args:
        center_xywh (torch.Tensor): Rows whose first four columns are ``(cx, cy, w, h)``.

    Returns:
        torch.Tensor: Shape (N, 4) rows ``[x, y, x + w, y + h]`` with ``x = cx - w/2``
        and ``y = cy - h/2``.
    """
    box_w = center_xywh[:,2]
    box_h = center_xywh[:,3]
    left = center_xywh[:,0] - box_w/2
    top = center_xywh[:,1] - box_h/2
    return torch.stack([left, top, left+box_w, top+box_h], dim=1)
|
||||||
|
|
||||||
|
def normalize_batch(x):
    """
    Min-max normalise every sample of a batch independently.

    Each ``x[i]`` is rescaled with the minimum and maximum taken over that whole sample.
    A sample whose values are all equal divides by zero.

    Args:
        x (torch.Tensor): Batched input; dim 0 is the batch axis.

    Returns:
        torch.Tensor: Normalised tensor of the same shape as the input.
    """
    # One statistic per sample, shaped (batch, 1, 1, ...) so it broadcasts against x.
    stat_shape = (x.shape[0],) + (1,) * len(x.shape[1:])
    mins = torch.zeros(stat_shape, device=x.device)
    maxs = torch.zeros(stat_shape, device=x.device)
    for i, sample in enumerate(x):
        mins[i] = sample.min()
        maxs[i] = sample.max()
    return (x - mins) / (maxs - mins)
|
||||||
|
|
||||||
|
def normalize_batch_nonan(x):
    """
    Min-max normalise every sample of a batch, passing all-zero samples through unchanged.

    Args:
        x (torch.Tensor): Batched input; dim 0 is the batch axis.

    Returns:
        torch.Tensor: Tensor of the same shape as the input. Samples that are entirely
        zero stay zero; every other sample is rescaled with its own minimum and maximum.
    """
    stat_shape = (x.shape[0],) + (1,) * len(x.shape[1:])
    mins = torch.zeros(stat_shape, device=x.device)
    maxs = torch.zeros(stat_shape, device=x.device)
    x_ = torch.zeros_like(x)
    for i in range(x.shape[0]):
        sample = x[i]
        if torch.all(sample == 0):
            # Skipping the division avoids 0/0 for an all-zero sample.
            x_[i] = sample
            continue
        mins[i] = sample.min()
        maxs[i] = sample.max()
        x_[i] = (sample - mins[i]) / (maxs[i] - mins[i])

    return x_
|
||||||
|
|
||||||
|
def get_detections(model_clone, img):
    """
    Run a model in evaluation mode, without gradients, on an input.

    Args:
        model_clone (nn.Module): Model to run; it is switched to eval mode and left there.
        img (torch.Tensor): Input passed to the model.

    Returns:
        tuple: The two values returned by ``model_clone(img)``, unchanged.
    """
    model_clone.eval()
    with torch.no_grad():
        det_out, out = model_clone(img)
    return det_out, out
|
||||||
|
|
||||||
|
def get_labels(det_out, imgs, targets, opt):
    """
    Convert raw detections into label rows ``[image_index, class, x, y, w, h]``.

    Detections are filtered with ``non_max_suppression``, sorted by confidence, cut off
    after the entries above 0.1 confidence, converted to xywh and divided by the image
    width/height.

    Args:
        det_out: Raw model output handed to ``non_max_suppression``.
        imgs (torch.Tensor): Image batch of shape (batch, channels, height, width).
        targets (torch.Tensor): Ground-truth rows; column 0 is the image index and
            columns 2: are scaled by (width, height, width, height).
        opt: Options object; ``opt.save_hybrid`` enables passing the scaled targets to
            NMS as ``labels``.

    Returns:
        torch.Tensor: All label rows of the batch concatenated, on ``imgs.device``.
    """
    nb, _, height, width = imgs.shape  # batch size, channels, height, width
    # One width/height gain for both scaling to pixels and normalising back. The
    # previous per-image gain came out ordered (h, w, h, w) and mis-scaled
    # non-square images.
    whwh = torch.Tensor([width, height, width, height]).to(imgs.device)
    targets_ = targets.clone()
    targets_[:, 2:] = targets_[:, 2:] * whwh  # to pixels
    lb = [targets_[targets_[:, 0] == i, 1:] for i in range(nb)] if opt.save_hybrid else []  # for autolabelling
    o = non_max_suppression(det_out, conf_thres=0.001, iou_thres=0.6, labels=lb, multi_label=True)
    pred_labels = []
    for si, pred in enumerate(o):
        # Sort detections by confidence (column 4), highest first.
        sort_indices = torch.argsort(pred[:, 4], dim=0, descending=True)
        sorted_pred = pred[sort_indices]
        # Keep every detection above 0.1 confidence plus the next one in rank.
        # NOTE(review): the "+ 1" retains one sub-threshold detection — confirm intended.
        n_conf = int(torch.sum(sorted_pred[:,4]>0.1)) + 1
        sorted_pred = sorted_pred[:n_conf]
        new_col = torch.ones((sorted_pred.shape[0], 1), device=imgs.device) * si
        preds = torch.cat((new_col, sorted_pred[:, [5, 0, 1, 2, 3]]), dim=1)
        preds[:, 2:] = xyxy2xywh(preds[:, 2:])  # xywh
        preds[:, 2:] /= whwh  # from pixels
        pred_labels.append(preds)
    pred_labels = torch.cat(pred_labels, 0).to(imgs.device)

    return pred_labels
|
||||||
|
##################################################################
|
||||||
|
|
||||||
|
from torchvision.utils import make_grid
|
||||||
|
|
||||||
|
def get_center_coords(attr):
    """
    Find and print the centroid of the brightest pixels of a 2-D map.

    The map is divided by its maximum; pixels above 0.95 of that maximum count as bright.

    Args:
        attr (torch.Tensor): 2-D map (the two-way unpacking of ``torch.where`` requires
            exactly two dimensions).

    Returns:
        tuple[float, float]: ``(centroid_x, centroid_y)`` of the bright pixels, where x
        is the column index and y the row index.
    """
    # The body previously read an undefined name instead of the `attr` argument.
    img_tensor = attr / attr.max()

    # Define a brightness threshold
    threshold = 0.95

    # Create a binary mask of the bright pixels
    mask = img_tensor > threshold

    # Get the coordinates of the bright pixels
    y_coords, x_coords = torch.where(mask)

    # Calculate the centroid of the bright pixels
    centroid_x = x_coords.float().mean().item()
    centroid_y = y_coords.float().mean().item()

    print(f'The central bright point is at ({centroid_x}, {centroid_y})')

    return centroid_x, centroid_y
|
||||||
|
|
||||||
|
|
||||||
|
def get_distance_grids(attr, targets, imgs=None, focus_coeff=0.5, debug=False):
    """
    Compute, per image, a normalised map of each pixel's distance to its nearest target.

    Args:
        attr (torch.Tensor): Attribution maps; only the shape, device and (in debug
            mode) values are used.
        targets (torch.Tensor): Target rows; column 0 is the image index and columns
            2:4 hold the target centre (assumed normalised to [0, 1] — TODO confirm).
        imgs (torch.Tensor, optional): Images used only for debug figures; a zero tensor
            shaped like ``attr`` is substituted when omitted.
        focus_coeff (float, optional): Exponent applied to the normalised distances,
            smaller means more focused. Defaults to 0.5.
        debug (bool, optional): Save debug figures for every 8th image. Defaults to False.

    Returns:
        torch.Tensor: Stacked distance grids, normalised per image with
        ``normalize_batch`` and raised to ``focus_coeff``; NaNs are replaced by 0.
    """

    # NOTE(review): `height` is read from the last axis and `width` from the
    # second-to-last — confirm the naming against the tensor layout.
    height, width = attr.shape[-1], attr.shape[-2]

    # Create a grid of indices
    xx, yy = torch.stack(torch.meshgrid(torch.arange(height), torch.arange(width))).to(attr.device)
    idx_grid = torch.stack((xx, yy), dim=-1).float()

    # Expand the grid to match the batch size
    idx_batch_grid = idx_grid.expand(attr.shape[0], -1, -1, -1)

    # Placeholder list; every entry is overwritten in the loop below
    dist_grids_ = [[]] * attr.shape[0]

    # Loop over batches
    for j in range(attr.shape[0]):
        # Select the target rows that belong to image j (boolean indexing copies them)
        rows = targets[targets[:, 0] == j]

        if len(rows) != 0:
            # Create a tensor for the target coordinates
            xy = rows[:,2:4] # y, x
            # Swap the two columns and scale them to the image size; the right-hand
            # side is fully evaluated before either column is written
            xy[:, 0], xy[:, 1] = xy[:, 1] * width, xy[:, 0] * height # y, x to x, y
            xy_center = xy.unsqueeze(1).unsqueeze(1)#.requires_grad_(True)

            # Compute the Euclidean distance from each pixel to the target coordinates
            dists = torch.norm(idx_batch_grid[j].expand(len(xy_center), -1, -1, -1) - xy_center, dim=-1)

            # Pick the closest distance to any target for each pixel
            dist_grid_ = torch.min(dists, dim=0)[0].unsqueeze(0)
            # Repeat to 3 channels when the attribution map has 3 channels
            dist_grid = torch.cat([dist_grid_, dist_grid_, dist_grid_], dim=0) if attr.shape[1] == 3 else dist_grid_
        else:
            # Set grid to zero if no targets are present
            dist_grid = torch.zeros_like(attr[j])

        dist_grids_[j] = dist_grid
    # Convert the list of distance grids to a tensor for faster computation
    dist_grids = normalize_batch(torch.stack(dist_grids_)) ** focus_coeff
    # An all-zero grid normalises to NaN (0/0); map those back to 0
    if torch.isnan(dist_grids).any():
        dist_grids = torch.nan_to_num(dist_grids, nan=0.0)

    if debug:
        for i in range(len(dist_grids)):
            if ((i % 8) == 0):
                grid_show = torch.cat([dist_grids[i][:1], dist_grids[i][:1], dist_grids[i][:1]], dim=0)
                imshow(grid_show, save_path='figs/dist_grids')
                if imgs is None:
                    imgs = torch.zeros_like(attr)
                imshow(imgs[i], save_path='figs/im0')
                img_overlay = (overlay_mask(imgs[i], dist_grids[i][0], alpha = 0.75))
                imshow(img_overlay, save_path='figs/dist_grid_overlay')
                weighted_attr = (dist_grids[i] * attr[i])
                imshow(weighted_attr, save_path='figs/weighted_attr')
                imshow(attr[i], save_path='figs/attr')

    return dist_grids
|
||||||
|
|
||||||
|
def attr_reg(attribution_map, distance_map):
    """
    Return the mean of the element-wise product of an attribution map and a weight map.

    Args:
        attribution_map (torch.Tensor): Attribution values.
        distance_map (torch.Tensor): Weights, broadcastable against ``attribution_map``.

    Returns:
        torch.Tensor: Scalar mean over all elements of the product.
    """
    weighted = distance_map * attribution_map
    return torch.mean(weighted)
|
||||||
|
|
||||||
|
def get_bbox_map(targets_out, attr, corners=False):
    """
    Build a float mask that is 1 inside every target box and 0 elsewhere.

    Args:
        targets_out (torch.Tensor): Target rows; column 0 is the image index and columns
            2:6 describe the box.
        attr (torch.Tensor): Tensor whose shape, device and batch layout the mask copies.
        corners (bool, optional): When True, columns 2:6 are cast to int and used directly
            as slice bounds; otherwise they are converted with ``corners_coords_batch``
            and scaled by the spatial size of ``attr``. Defaults to False.

    Returns:
        torch.Tensor: float32 mask shaped like ``attr``.
    """
    batch_idx = targets_out[:, 0].int()
    boxes = targets_out[:, 2:6]
    spatial = torch.tensor([attr.shape[2], attr.shape[3], attr.shape[2], attr.shape[3]], device=attr.device)
    scale = torch.tile(spatial, (boxes.shape[0], 1))
    co = (corners_coords_batch(boxes) * scale).int()
    if corners:
        co = targets_out[:, 2:6].int()

    mask = torch.zeros_like(attr, dtype=torch.bool)
    for n in range(co.shape[0]):
        # Columns 1/3 bound dim 2 and columns 0/2 bound dim 3; all channels are filled.
        mask[batch_idx[n], :, co[n, 1]:co[n, 3], co[n, 0]:co[n, 2]] = True

    return mask.to(torch.float32)
|
||||||
|
######################################## BCE #######################################
|
||||||
|
def get_plaus_loss(targets, attribution_map, opt, imgs=None, debug=False, only_loss=False):
    """
    Compute the plausibility loss of an attribution map against the targets.

    Args:
        targets (torch.Tensor): Target rows passed to the box and distance helpers.
        attribution_map (torch.Tensor): Attribution maps to be scored.
        opt: Options object; reads ``focus_coeff``, ``dist_x_bbox``, ``bbox_coeff``,
            ``dist_reg_only``, ``iou_coeff``, ``dist_coeff`` and ``pgt_coeff``.
        imgs (torch.Tensor, optional): Images forwarded to the helpers.
        debug (bool, optional): Also return the distance map. Defaults to False.
        only_loss (bool, optional): Return only the loss tensor. Defaults to False.

    Returns:
        ``plaus_loss`` when ``only_loss`` is true; otherwise
        ``(plaus_loss, (plaus_score, dist_reg, plaus_reg))``, with ``distance_map``
        appended as a third element when ``debug`` is true.
    """
    # NOTE(review): this value is overwritten by the bbox-based plaus_score further
    # down before it is read, so this call appears to have no effect on the result.
    if not only_loss:
        plaus_score = get_plaus_score(targets_out = targets, attr = attribution_map.clone().detach().requires_grad_(True), imgs = imgs)
    else:
        plaus_score = torch.tensor(0.0)

    # Calculate distance regularization
    distance_map = get_distance_grids(attribution_map, targets, imgs, opt.focus_coeff)

    # Optionally zero the distance weights inside the target boxes
    if opt.dist_x_bbox:
        bbox_map = get_bbox_map(targets, attribution_map).to(torch.bool)
        distance_map[bbox_map] = 0.0

    # Positive regularization term for incentivizing pixels near the target to have high attribution
    dist_attr_pos = attr_reg(attribution_map, (1.0 - distance_map))
    # Negative regularization term for incentivizing pixels far from the target to have low attribution
    dist_attr_neg = attr_reg(attribution_map, distance_map)
    # Calculate plausibility regularization term, relative to the mean attribution
    dist_reg = ((dist_attr_pos / torch.mean(attribution_map)) - (dist_attr_neg / torch.mean(attribution_map)))

    # Optional box term: mean attribution weighted inside minus outside the boxes
    if opt.bbox_coeff != 0.0:
        bbox_map = get_bbox_map(targets, attribution_map)
        attr_bbox_pos = attr_reg(attribution_map, bbox_map)
        attr_bbox_neg = attr_reg(attribution_map, (1.0 - bbox_map))
        bbox_reg = attr_bbox_pos - attr_bbox_neg
    else:
        bbox_reg = 0.0

    # Share of the total attribution that falls inside the target boxes
    bbox_map = get_bbox_map(targets, attribution_map)
    plaus_score = ((torch.sum((attribution_map * bbox_map))) / (torch.sum(attribution_map)))

    if not opt.dist_reg_only:
        dist_reg_loss = (((1.0 + dist_reg) / 2.0))
        plaus_reg = (plaus_score * opt.iou_coeff) + \
                    (((dist_reg_loss * opt.dist_coeff) + \
                      (bbox_reg * opt.bbox_coeff))\
                    # ((((((1.0 + dist_reg) / 2.0) - 1.0) * opt.dist_coeff) + ((((1.0 + bbox_reg) / 2.0) - 1.0) * opt.bbox_coeff))\
                    # / (plaus_score) \
                    )
    else:
        plaus_reg = (((1.0 + dist_reg) / 2.0))
    # Calculate plausibility loss
    plaus_loss = (1 - plaus_reg) * opt.pgt_coeff
    if only_loss:
        return plaus_loss
    if not debug:
        return plaus_loss, (plaus_score, dist_reg, plaus_reg,)
    else:
        return plaus_loss, (plaus_score, dist_reg, plaus_reg,), distance_map
|
||||||
|
|
||||||
|
####################################################################################
|
||||||
|
#### ALL FUNCTIONS BELOW ARE DEPRECATED AND WILL BE REMOVED IN FUTURE VERSIONS ####
|
||||||
|
####################################################################################
|
||||||
|
|
||||||
|
def generate_vanilla_grad(model, input_tensor, loss_func = None,
                          targets_list=None, targets=None, metric=None, out_num = 1,
                          n_max_labels=3, norm=True, abs=True, grayscale=True,
                          class_specific_attr = True, device='cpu'):
    """
    Generate vanilla-gradient attribution maps of a model output (or loss) w.r.t. the input.

    The model is temporarily switched to train mode for the forward pass and is
    always switched back to its original mode before returning, including when
    an exception is raised.

    Args:
        model (nn.Module): The model to generate gradients for.
        input_tensor (torch.Tensor): The input batch. ``requires_grad`` is set to
            True on it in place.
        loss_func (callable, optional): If given, gradients are taken of
            ``loss_func(train_out, target, input, metric=metric)[0]`` instead of a raw
            model output. Defaults to None.
        targets_list (list, optional): One label tensor per image; column 1 is read as
            the class index. Required when ``class_specific_attr`` is True. Entries
            longer than ``n_max_labels`` are truncated in place. Defaults to None.
        targets (torch.Tensor, optional): Targets passed to ``loss_func`` when
            ``class_specific_attr`` is False. Defaults to None.
        metric (optional): Forwarded to ``loss_func`` as ``metric``. Defaults to None.
        out_num (int, optional): Index into the model's training outputs to
            differentiate when no ``loss_func`` is given. Defaults to 1.
        n_max_labels (int, optional): Maximum number of labels kept per image. Defaults to 3.
        norm (bool, optional): Whether to normalize each map with ``normalize_batch``.
            Defaults to True.
        abs (bool, optional): Whether to take the absolute value of the gradients.
            Defaults to True.
        grayscale (bool, optional): Whether to sum the gradients over dim 1
            (keeping the dim). Defaults to True.
        class_specific_attr (bool, optional): Whether to compute one map per label.
            Defaults to True.
        device (str, optional): Device the returned maps are moved to. Defaults to 'cpu'.

    Returns:
        list[torch.Tensor]: One entry per processed image (a single entry when
        ``loss_func`` is given or ``class_specific_attr`` is False). Each entry is the
        stack of that image's attribution maps, or an all-zero tensor shaped like the
        input with a leading dim of 1 when the image has no labels.
    """
    # Run in train mode and revert to the original mode (eval or train) at the end.
    train_mode = model.training
    if not train_mode:
        model.train()

    try:
        input_tensor.requires_grad = True  # needed to differentiate w.r.t. the input
        model.zero_grad()
        inpt = input_tensor
        # Forward pass: training outputs (no inference outputs in train mode).
        train_out = model(inpt)

        # Example output shapes noted by the original author (model specific):
        # train_out[0] = [4, 3, 160, 160, 7], train_out[1] = [4, 3, 80, 80, 7],
        # train_out[2] = [4, 3, 40, 40, 7]

        if class_specific_attr:
            n_attr_list, index_classes = [], []
            for i in range(len(input_tensor)):
                if len(targets_list[i]) > n_max_labels:
                    targets_list[i] = targets_list[i][:n_max_labels]
                if targets_list[i].numel() != 0:
                    class_numbers = targets_list[i][:, 1]
                    # One index list per label: the first five channels plus the channel
                    # given by the label's class value.
                    index_classes.append([[0, 1, 2, 3, 4, int(uc)] for uc in class_numbers])
                    n_attr_list.append(len(targets_list[i]))
                else:
                    index_classes.append([0, 1, 2, 3, 4])
                    n_attr_list.append(0)

            # Repeat each image's labels so every image has the same label count,
            # which lets a whole batch of "i-th labels" be sliced out below.
            targets_list_filled = [targ.clone().detach() for targ in targets_list]
            labels_len = [len(targ) for targ in targets_list]
            max_labels = np.max(labels_len)
            max_index = np.argmax(labels_len)
            for i in range(len(targets_list)):
                n_labels = len(targets_list_filled[i])
                # Images without labels cannot be tiled (and would divide by zero);
                # they are dropped right below instead.
                if 0 < n_labels < max_labels:
                    tlist = [targets_list_filled[i]] * math.ceil(max_labels / n_labels)
                    targets_list_filled[i] = torch.cat(tlist)[:max_labels].unsqueeze(0)
                else:
                    targets_list_filled[i] = targets_list_filled[i].unsqueeze(0)
            targets_list_filled = [targ for targ in targets_list_filled if targ.numel() != 0]
            # torch.cat rejects an empty list; with no labels anywhere the filled
            # targets are never read because every label loop below is empty.
            targets_list_filled = torch.cat(targets_list_filled) if targets_list_filled else None

        n_img_attrs = len(input_tensor) if class_specific_attr else 1
        n_img_attrs = 1 if loss_func else n_img_attrs

        attrs_batch = []
        for i_batch in range(n_img_attrs):
            if loss_func and class_specific_attr:
                # With a loss function the whole batch is handled in one pass, driven
                # by the image that has the most labels.
                i_batch = max_index
            n_label_attrs = n_attr_list[i_batch] if class_specific_attr else 1
            attrs_img = []
            for i_attr in range(n_label_attrs):
                if loss_func is None:
                    grad_wrt = train_out[out_num]
                    if class_specific_attr:
                        grad_wrt = train_out[out_num][:, :, :, :, index_classes[i_batch][i_attr]]
                    grad_wrt_outputs = torch.ones_like(grad_wrt)
                else:
                    if class_specific_attr:
                        target_indiv = targets_list_filled[:, i_attr]  # batch image input
                    else:
                        target_indiv = targets

                    try:
                        loss, loss_items = loss_func(train_out, target_indiv, inpt, metric=metric)  # loss scaled by batch_size
                    except Exception:
                        # Retry once with everything moved to `device`.
                        # NOTE(review): if `inpt` actually changes device here, the
                        # autograd.grad call below differentiates w.r.t. the moved copy
                        # - confirm that is intended.
                        target_indiv = target_indiv.to(device)
                        inpt = inpt.to(device)
                        train_out = [tro.to(device) for tro in train_out]
                        print("Error in loss function, trying again with device specified")
                        loss, loss_items = loss_func(train_out, target_indiv, inpt, metric=metric)
                    grad_wrt = loss
                    grad_wrt_outputs = None

                model.zero_grad()
                # retain_graph keeps the single forward pass reusable for every label.
                gradients = torch.autograd.grad(grad_wrt, inpt,
                                                grad_outputs=grad_wrt_outputs,
                                                retain_graph=True,
                                                )
                attribution_map = gradients[0]

                if grayscale:  # sum over dim 1; saves vram and computation time for plaus_eval
                    attribution_map = torch.sum(attribution_map, 1, keepdim=True)
                if abs:
                    attribution_map = torch.abs(attribution_map)
                if norm:
                    attribution_map = normalize_batch(attribution_map)  # per image in batch
                attrs_img.append(attribution_map)

            if len(attrs_img) == 0:
                attrs_batch.append((torch.zeros_like(inpt).unsqueeze(0)).to(device))
            else:
                attrs_batch.append(torch.stack(attrs_img).to(device))

        return attrs_batch
    finally:
        # Set model back to original mode, even if something above raised.
        if not train_mode:
            model.eval()
|
||||||
|
|
||||||
|
class RVNonLinearFunc(torch.nn.Module):
    """
    Propagate a random variable (mean and covariance) through an element-wise function.

    The mean is passed through ``func`` directly. The covariance is scaled by the
    outer product of the derivative of ``func`` evaluated at the mean (a first-order
    approximation). ``func`` can be any differentiable callable such as
    ``torch.relu``.

    Attributes:
        func (callable): The non-linearity applied to the mean.
    """
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, mu_in, Sigma_in):
        """
        Apply the non-linearity to the mean and rescale the covariance accordingly.

        Args:
            mu_in (torch.Tensor): Mean input; its first dim is the batch size and the
                remaining dims are flattened to ``input_size`` for the covariance update.
            Sigma_in (torch.Tensor): Covariance input, multiplied element-wise with a
                ``(batch_size, input_size, input_size)`` tensor.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The mean of the output and the
            covariance of the output.
        """
        batch_size = mu_in.size(0)

        # autograd.grad needs an input that is part of the graph; a plain tensor is
        # re-wrapped so the derivative can still be computed.
        if not mu_in.requires_grad:
            mu_in = mu_in.detach().requires_grad_(True)

        # Mean
        mu_out = self.func(mu_in)

        # Derivative of the function w.r.t. the input mean, flattened per sample.
        # reshape (rather than view) also accepts non-contiguous gradients.
        gradi = torch.autograd.grad(
            mu_out, mu_in, grad_outputs=torch.ones_like(mu_out), create_graph=True
        )[0].reshape(batch_size, -1)

        # Outer product of the derivative with itself: (B, N, 1) @ (B, 1, N) -> (B, N, N)
        outer_product = torch.bmm(gradi.unsqueeze(dim=2), gradi.unsqueeze(dim=1))

        # Element-wise scaling of the input covariance.
        Sigma_out = torch.mul(Sigma_in, outer_product)

        return mu_out, Sigma_out
|
||||||
|
|
154
ultralytics/utils/plot_functs.py
Normal file
154
ultralytics/utils/plot_functs.py
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
class Subplots:
    """Helper that lays out rows of images on a single matplotlib figure."""

    def __init__(self, figsize = (40, 5)):
        self.fig = plt.figure(figsize=figsize)

    def plot_img_list(self, img_list, savedir='figs/test',
                      nrows = 1, rownum = 0,
                      hold = False, coltitles=None, rowtitle=''):
        """
        Draw one row of images on the figure and optionally save it.

        Args:
            img_list: Images as torch tensors or numpy arrays; each is transposed with
                axes (1, 2, 0) before display.
            savedir (str): Output path without extension; ``.png`` is appended.
            nrows (int): Total number of rows in the subplot grid.
            rownum (int): Zero-based row this call fills.
            hold (bool): If True, only draw; do not save or close the figure, so
                further rows can be added by later calls.
            coltitles (sequence, optional): Column titles; applied to as many columns
                as there are titles.
            rowtitle (str): Text annotated next to the first image of the row.
        """
        coltitles = [] if coltitles is None else coltitles
        lenrow = int(len(img_list))

        for i, img in enumerate(img_list):
            if isinstance(img, torch.Tensor):
                npimg = img.clone().detach().cpu().numpy()
            else:
                npimg = img
            tpimg = np.transpose(npimg, (1, 2, 0))
            ax = self.fig.add_subplot(nrows, lenrow, i + 1 + (rownum * lenrow))
            if len(coltitles) > i:
                ax.set_title(coltitles[i])
            if i == 0:
                # Shift the row title left in proportion to its length.
                ax.annotate(rowtitle, xy=((-0.06 * len(rowtitle)), 0.4),
                            xycoords='axes fraction', textcoords='offset points',
                            size='large', ha='center', va='baseline')
            ax.imshow(tpimg)
            ax.axis('off')

        if not hold:
            # Act on this object's figure explicitly rather than on whichever figure
            # happens to be current.
            self.fig.tight_layout()
            self.fig.savefig(f'{savedir}.png')
            self.fig.clf()
            plt.close('all')
|
||||||
|
|
||||||
|
|
||||||
|
def VisualizeNumpyImageGrayscale(image_3d):
    r"""Return the array min-max normalized to the range [0, 1].

    The input shape is preserved (no channel reduction is performed). A constant
    array, whose range is zero, yields all zeros instead of dividing by zero.
    """
    shifted = image_3d - np.min(image_3d)
    vmax = np.max(shifted)
    # A zero range means every value equals the minimum; divide by 1 to get zeros.
    scale = vmax if vmax != 0 else 1
    return (shifted / scale)
|
||||||
|
|
||||||
|
def normalize_numpy(image_3d):
    r"""Return the numpy array min-max normalized to the range [0, 1].

    The input shape is preserved. A constant array, whose range is zero, yields
    all zeros instead of dividing by zero.
    """
    shifted = image_3d - np.min(image_3d)
    vmax = np.max(shifted)
    # A zero range means every value equals the minimum; divide by 1 to get zeros.
    scale = vmax if vmax != 0 else 1
    return (shifted / scale)
|
||||||
|
|
||||||
|
# def normalize_tensor(image_3d):
|
||||||
|
# r"""Returns a 3D tensor as a grayscale normalized between 0 and 1 2D tensor.
|
||||||
|
# """
|
||||||
|
# vmin = torch.min(image_3d)
|
||||||
|
# image_2d = image_3d - vmin
|
||||||
|
# vmax = torch.max(image_2d)
|
||||||
|
# return (image_2d / vmax)
|
||||||
|
|
||||||
|
def normalize_tensor(image_3d):
    r"""Return the tensor min-max normalized to the range [0, 1].

    The input shape is preserved. A constant tensor, whose range is zero, yields
    all zeros instead of dividing by zero.
    """
    shifted = (image_3d - torch.min(image_3d))
    peak = torch.max(shifted)
    # Swap a zero peak for one so the division returns zeros rather than NaNs.
    safe_peak = torch.where(peak == 0, torch.ones_like(peak), peak)
    return (shifted / safe_peak)
|
||||||
|
|
||||||
|
def format_img(img_):
    """Convert a tensor to a numpy array with its axes reordered as (1, 2, 0).

    The tensor is detached and moved to the CPU first, so tensors that require
    grad or live on another device are accepted as well.
    """
    np_img = img_.detach().cpu().numpy()
    tp_img = np.transpose(np_img, (1, 2, 0))
    return tp_img
|
||||||
|
|
||||||
|
def imshow(img, save_path=None):
    """Show an image on the current matplotlib figure and optionally save it.

    Args:
        img: A torch tensor or numpy array; it is transposed with axes (1, 2, 0)
            before being displayed.
        save_path (optional): Output path without extension; ``.png`` is appended.
            Nothing is saved when it is None.
    """
    if isinstance(img, torch.Tensor):
        npimg = img.clone().detach().cpu().numpy()
    else:
        npimg = img
    tpimg = np.transpose(npimg, (1, 2, 0))
    plt.imshow(tpimg)
    plt.tight_layout()
    if save_path is not None:
        plt.savefig(str(str(save_path) + ".png"))
|
||||||
|
|
||||||
|
def imshow_img(img, imsave_path):
    """Min-max normalize an image and save it through ``imshow``.

    Works for tensors and numpy arrays. The normalized array is transposed with
    axes (2, 0, 1) before being handed to ``imshow``, which transposes it back with
    (1, 2, 0) for display.

    Args:
        img: A torch tensor or numpy array.
        imsave_path: Output path without extension, forwarded to ``imshow``.
    """
    if isinstance(img, torch.Tensor):
        # detach/cpu so tensors that require grad or live on another device convert too
        img = img.detach().cpu().numpy()
    npimg = VisualizeNumpyImageGrayscale(img)
    npimg = np.transpose(npimg, (2, 0, 1))
    imshow(npimg, save_path=imsave_path)
    print("Saving image as ", imsave_path)
|
||||||
|
|
||||||
|
def returnGrad(img, labels, model, compute_loss, loss_metric, augment=None, device = 'cpu'):
    """Return the gradient of the loss with respect to the input image.

    The model is switched to train mode for the forward/backward pass and left in
    eval mode afterwards.

    Args:
        img (torch.Tensor): Input batch; ``requires_grad`` is enabled on it.
        labels (torch.Tensor): Targets passed to ``compute_loss``.
        model (nn.Module): Model to differentiate through.
        compute_loss (callable): Called as ``compute_loss(pred, labels, metric=loss_metric)``
            and expected to return ``(loss, loss_items)``.
        loss_metric: Forwarded to ``compute_loss`` as ``metric``.
        augment: Unused.
        device (str | torch.device): Device to run on. Defaults to 'cpu'.

    Returns:
        torch.Tensor: float32 copy of ``img.grad``, detached from the graph.
    """
    device = torch.device(device)  # accept the string default as well as torch.device
    model.train()
    model.to(device)
    img = img.to(device)
    img.requires_grad_(True)
    labels = labels.to(device)  # keep the moved tensor; only img needs a gradient
    model.requires_grad_(True)
    cuda = device.type != 'cpu'
    # NOTE(review): when the scaler is enabled the returned gradient carries the loss
    # scale factor, since it is never unscaled - confirm callers normalize it.
    scaler = torch.cuda.amp.GradScaler(enabled=cuda)
    pred = model(img)
    loss, loss_items = compute_loss(pred, labels, metric=loss_metric)

    with torch.autograd.set_detect_anomaly(True):
        # inputs=img restricts gradient accumulation to the image tensor
        scaler.scale(loss).backward(inputs=img)

    Sc_dx = img.grad
    model.eval()
    return Sc_dx.detach().clone().to(dtype=torch.float32)
|
||||||
|
|
||||||
|
def calculate_snr(img, attr, dB=True):
    """Compute the ratio of the mean squared image to the mean squared attribution.

    Args:
        img: Signal, as a torch tensor or array-like.
        attr: Noise term, as a torch tensor or array-like.
        dB (bool): If truthy, return ``10 * log10(ratio)``; otherwise the plain ratio.

    Returns:
        The signal-to-noise ratio as a numpy scalar.
    """
    def _as_numpy(value):
        # Convert each argument on its own so a tensor/array mix is handled too.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    img_np = _as_numpy(img)
    attr_np = _as_numpy(attr)

    # Mean power of the signal and of the noise
    signal_power = np.mean(img_np**2)
    noise_power = np.mean(attr_np**2)

    ratio = signal_power / noise_power
    if dB:
        return 10 * np.log10(ratio)
    return ratio
|
||||||
|
|
||||||
|
def overlay_mask(img, mask, colormap: str = "jet", alpha: float = 0.7):
    """Blend a colormapped mask over an image.

    Args:
        img (torch.Tensor): Image tensor; it is weighted by ``alpha``.
        mask (torch.Tensor): Mask tensor; dim 0 is squeezed before the colormap is
            applied. It is weighted by ``1 - alpha``.
        colormap (str): Name of the matplotlib colormap.
        alpha (float): Weight of the image in the blend.

    Returns:
        torch.Tensor: The blended result, placed on ``img``'s device.
    """
    mask_np = np.array(mask.detach().cpu().squeeze(0))
    # Colormap output has the colour channels last; drop alpha and move them first.
    rgb = plt.get_cmap(colormap)(mask_np)[:, :, :3]
    heat = rgb.transpose((2, 0, 1))
    base = np.asarray(img.detach().cpu())
    blended = alpha * base + (1 - alpha) * heat
    return torch.tensor(blended, device=img.device)
|
@ -717,7 +717,7 @@ def plot_images(
|
|||||||
):
|
):
|
||||||
"""Plot image grid with labels."""
|
"""Plot image grid with labels."""
|
||||||
if isinstance(images, torch.Tensor):
|
if isinstance(images, torch.Tensor):
|
||||||
images = images.cpu().float().numpy()
|
images = images.detach().cpu().float().numpy()
|
||||||
if isinstance(cls, torch.Tensor):
|
if isinstance(cls, torch.Tensor):
|
||||||
cls = cls.cpu().numpy()
|
cls = cls.cpu().numpy()
|
||||||
if isinstance(bboxes, torch.Tensor):
|
if isinstance(bboxes, torch.Tensor):
|
||||||
@ -837,6 +837,144 @@ def plot_images(
|
|||||||
if on_plot:
|
if on_plot:
|
||||||
on_plot(fname)
|
on_plot(fname)
|
||||||
|
|
||||||
|
@threaded
def plot_gradients(
    images,
    batch_idx,
    cls,
    bboxes=np.zeros(0, dtype=np.float32),
    confs=None,
    masks=np.zeros(0, dtype=np.uint8),
    kpts=np.zeros((0, 51), dtype=np.float32),
    paths=None,
    fname="images.jpg",
    names=None,
    on_plot=None,
    max_subplots=16,
    save=True,
    conf_thres=0.25,
):
    """
    Plot a batch of images as an annotated mosaic grid.

    The name suggests the images are gradient/attribution maps; the body itself
    treats them as ordinary image arrays.

    Args:
        images (torch.Tensor | np.ndarray): Batch unpacked as (batch, _, height, width).
        batch_idx (torch.Tensor | np.ndarray): Image index of every label.
        cls (torch.Tensor | np.ndarray): Class of every label.
        bboxes (torch.Tensor | np.ndarray): Boxes in xywh, or xywhr when the last dim is 5.
        confs (torch.Tensor | np.ndarray, optional): Confidences; None means the
            annotations are ground-truth labels and nothing is thresholded.
        masks (torch.Tensor | np.ndarray): Instance masks.
        kpts (torch.Tensor | np.ndarray): Keypoints.
        paths (list, optional): Image paths; the file name is drawn on each tile.
        fname (str): Output file name.
        names (dict, optional): Mapping from class index to display name.
        on_plot (callable, optional): Called with ``fname`` after saving.
        max_subplots (int): Maximum number of images placed in the mosaic.
        save (bool): If False, return the mosaic as an array instead of saving it.
        conf_thres (float): Minimum confidence for a prediction to be drawn.

    Returns:
        np.ndarray | None: The annotated mosaic when ``save`` is False, otherwise None.
    """
    if isinstance(images, torch.Tensor):
        images = images.detach().cpu().float().numpy()
    if isinstance(cls, torch.Tensor):
        cls = cls.cpu().numpy()
    if isinstance(bboxes, torch.Tensor):
        bboxes = bboxes.cpu().numpy()
    if isinstance(masks, torch.Tensor):
        masks = masks.cpu().numpy().astype(int)
    if isinstance(kpts, torch.Tensor):
        kpts = kpts.cpu().numpy()
    if isinstance(batch_idx, torch.Tensor):
        batch_idx = batch_idx.cpu().numpy()

    max_size = 1920  # max image size
    bs, _, h, w = images.shape  # batch size, _, height, width
    bs = min(bs, max_subplots)  # limit plot images
    ns = np.ceil(bs**0.5)  # number of subplots (square)
    if np.max(images[0]) <= 1:
        # de-normalise (optional); out of place, because the array can share memory
        # with the caller's CPU tensor or be the caller's own array
        images = images * 255

    # Build Image
    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
    for i in range(bs):
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0)

    # Resize (optional)
    scale = max_size / ns / max(h, w)
    if scale < 1:
        h = math.ceil(scale * h)
        w = math.ceil(scale * w)
        mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))

    # Annotate
    fs = int((h + w) * ns * 0.01)  # font size
    annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
    for i in range(bs):
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2)  # borders
        if paths:
            annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220))  # filenames
        if len(cls) > 0:
            idx = batch_idx == i
            classes = cls[idx].astype("int")
            labels = confs is None  # True for ground-truth labels, False for predictions

            if len(bboxes):
                boxes = bboxes[idx]
                conf = confs[idx] if confs is not None else None  # check for confidence presence (label vs pred)
                is_obb = boxes.shape[-1] == 5  # xywhr
                boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
                if len(boxes):
                    if boxes[:, :4].max() <= 1.1:  # if normalized with tolerance 0.1
                        boxes[..., 0::2] *= w  # scale to pixels
                        boxes[..., 1::2] *= h
                    elif scale < 1:  # absolute coords need scale if image scales
                        boxes[..., :4] *= scale
                boxes[..., 0::2] += x
                boxes[..., 1::2] += y
                for j, box in enumerate(boxes.astype(np.int64).tolist()):
                    c = classes[j]
                    color = colors(c)
                    c = names.get(c, c) if names else c
                    if labels or conf[j] > conf_thres:
                        label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
                        annotator.box_label(box, label, color=color, rotated=is_obb)

            elif len(classes):
                for c in classes:
                    color = colors(c)
                    c = names.get(c, c) if names else c
                    annotator.text((x, y), f"{c}", txt_color=color, box_style=True)

            # Plot keypoints
            if len(kpts):
                kpts_ = kpts[idx].copy()
                if len(kpts_):
                    if kpts_[..., 0].max() <= 1.01 or kpts_[..., 1].max() <= 1.01:  # if normalized with tolerance .01
                        kpts_[..., 0] *= w  # scale to pixels
                        kpts_[..., 1] *= h
                    elif scale < 1:  # absolute coords need scale if image scales
                        kpts_ *= scale
                kpts_[..., 0] += x
                kpts_[..., 1] += y
                for j in range(len(kpts_)):
                    if labels or conf[j] > conf_thres:
                        annotator.kpts(kpts_[j])

            # Plot masks
            if len(masks):
                if idx.shape[0] == masks.shape[0]:  # overlap_masks=False
                    image_masks = masks[idx]
                else:  # overlap_masks=True
                    image_masks = masks[[i]]  # (1, 640, 640)
                    nl = idx.sum()
                    index = np.arange(nl).reshape((nl, 1, 1)) + 1
                    image_masks = np.repeat(image_masks, nl, axis=0)
                    image_masks = np.where(image_masks == index, 1.0, 0.0)

                im = np.asarray(annotator.im).copy()
                for j in range(len(image_masks)):
                    if labels or conf[j] > conf_thres:
                        color = colors(classes[j])
                        mh, mw = image_masks[j].shape
                        if mh != h or mw != w:
                            mask = image_masks[j].astype(np.uint8)
                            mask = cv2.resize(mask, (w, h))
                            mask = mask.astype(bool)
                        else:
                            mask = image_masks[j].astype(bool)
                        # Best effort: a mask that does not fit its tile is skipped.
                        with contextlib.suppress(Exception):
                            im[y : y + h, x : x + w, :][mask] = (
                                im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6
                            )
                annotator.fromarray(im)
    if not save:
        return np.asarray(annotator.im)
    annotator.im.save(fname)  # save
    if on_plot:
        on_plot(fname)
|
||||||
|
|
||||||
@plt_settings()
|
@plt_settings()
|
||||||
def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
|
def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user