2024-10-23 20:33:28 -04:00

75 lines
2.8 KiB
Python

# Ultralytics YOLO 🚀, AGPL-3.0 license
from copy import copy
from ultralytics.models import yolo
from ultralytics.nn.tasks import SegmentationModel, DetectionModel
from ultralytics.utils import DEFAULT_CFG, RANK
# from ultralytics.utils import yaml_load, IterableSimpleNamespace, ROOT
from ultralytics.utils.plotting import plot_images, plot_results
from ultralytics.models.yolov10.model import YOLOv10PGTDetectionModel
from ultralytics.models.yolov10.val import YOLOv10PGTDetectionValidator
# Default configuration (currently disabled; kept for reference)
# DEFAULT_CFG_DICT = yaml_load(ROOT / "cfg/pgt_train.yaml")
# for k, v in DEFAULT_CFG_DICT.items():
#     if isinstance(v, str) and v.lower() == "none":
#         DEFAULT_CFG_DICT[k] = None
# DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
# DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)
class PGTSegmentationTrainer(yolo.detect.PGTDetectionTrainer):
    """
    Trainer that extends PGTDetectionTrainer with the task forced to "segment".

    Despite the class name, `get_model` builds a `YOLOv10PGTDetectionModel` and `get_validator`
    returns a `YOLOv10PGTDetectionValidator`; only the task override and the plotting helpers
    (mask overlay, `segment=True` results plot) are segmentation-specific.

    Example:
        ```python
        args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml', epochs=3)
        trainer = PGTSegmentationTrainer(overrides=args)
        trainer.train()
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize the trainer, forcing `task="segment"` in the overrides.

        Args:
            cfg: Base configuration forwarded to the parent trainer. Defaults to `DEFAULT_CFG`.
            overrides (dict | None): Optional configuration overrides. Any existing "task" entry
                is replaced by "segment". The caller's dict is not modified.
            _callbacks: Optional callbacks forwarded unchanged to the parent trainer.
        """
        # Build a fresh dict so the forced task does not leak back into the caller's mapping.
        overrides = {**(overrides or {}), "task": "segment"}
        super().__init__(cfg, overrides, _callbacks)

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Return a YOLOv10PGTDetectionModel built from `cfg`, optionally loading `weights`.

        Args:
            cfg: Model configuration passed straight to `YOLOv10PGTDetectionModel`.
            weights: If truthy, passed to `model.load()` after construction.
            verbose (bool): Requests verbose model construction; only honoured when `RANK == -1`
                (presumably the non-distributed case -- confirm against `ultralytics.utils.RANK`).

        Returns:
            The constructed model, with the class count taken from `self.data["nc"]`.
        """
        model = YOLOv10PGTDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model

    def get_validator(self):
        """
        Return a YOLOv10PGTDetectionValidator and register the loss names used by this trainer.

        Side effect: sets `self.loss_names` to the seven-entry tuple below (one-to-many and
        one-to-one box/cls/dfl terms plus the PGT loss, going by the `_om` / `_oo` suffixes).

        Returns:
            A validator bound to `self.test_loader`, `self.save_dir`, a shallow copy of
            `self.args`, and `self.callbacks`.
        """
        self.loss_names = (
            "box_om",
            "cls_om",
            "dfl_om",
            "box_oo",
            "cls_oo",
            "dfl_oo",
            "pgt_loss",
        )
        return YOLOv10PGTDetectionValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

    def plot_training_samples(self, batch, ni):
        """
        Plot a training batch with its labels, boxes and masks.

        Args:
            batch (dict): Batch providing the keys "img", "batch_idx", "cls", "bboxes", "masks"
                and "im_file". A missing key raises `KeyError`.
            ni: Value interpolated into the output file name `train_batch{ni}.jpg`.
        """
        plot_images(
            batch["img"],
            batch["batch_idx"],
            batch["cls"].squeeze(-1),
            batch["bboxes"],
            masks=batch["masks"],
            paths=batch["im_file"],
            fname=self.save_dir / f"train_batch{ni}.jpg",
            on_plot=self.on_plot,
        )

    def plot_metrics(self):
        """Plots training/val metrics from `self.csv`."""
        # NOTE(review): `segment=True` selects the segmentation column layout, while validation
        # uses a detection validator with seven custom loss names -- confirm the plotted columns
        # line up with what is actually written to the results CSV.
        plot_results(file=self.csv, segment=True, on_plot=self.on_plot)  # save results.png