# Ultralytics YOLO 🚀, AGPL-3.0 license

from copy import copy

from ultralytics.models import yolo
from ultralytics.nn.tasks import SegmentationModel, DetectionModel
from ultralytics.utils import DEFAULT_CFG, RANK
# from ultralytics.utils import yaml_load, IterableSimpleNamespace, ROOT
from ultralytics.utils.plotting import plot_images, plot_results
from ultralytics.models.yolov10.model import YOLOv10PGTDetectionModel
from ultralytics.models.yolov10.val import YOLOv10PGTDetectionValidator

# # Default configuration
# DEFAULT_CFG_DICT = yaml_load(ROOT / "cfg/pgt_train.yaml")
# for k, v in DEFAULT_CFG_DICT.items():
#     if isinstance(v, str) and v.lower() == "none":
#         DEFAULT_CFG_DICT[k] = None
# DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
# DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)


class PGTSegmentationTrainer(yolo.detect.PGTDetectionTrainer):
    """
    Trainer built on ``PGTDetectionTrainer`` that forces ``task="segment"``.

    It builds a ``YOLOv10PGTDetectionModel`` and validates it with a
    ``YOLOv10PGTDetectionValidator``. Training-sample plots include masks, and
    metric plots use the segment-style results layout.

    Example:
        ```python
        args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml', epochs=3)
        trainer = PGTSegmentationTrainer(overrides=args)
        trainer.train()
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize the trainer with ``task`` forced to ``"segment"``.

        Args:
            cfg: Base configuration (defaults to ``DEFAULT_CFG``).
            overrides (dict | None): Configuration overrides. The caller's dict is
                copied, not modified.
            _callbacks: Optional callbacks forwarded to the parent trainer.
        """
        # Copy first so forcing the task does not leak back into the caller's dict.
        overrides = {} if overrides is None else dict(overrides)
        overrides["task"] = "segment"
        super().__init__(cfg, overrides, _callbacks)

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Build a ``YOLOv10PGTDetectionModel`` and optionally load weights into it.

        Args:
            cfg: Model configuration passed to ``YOLOv10PGTDetectionModel``.
            weights: Weights to load into the model; skipped when falsy.
            verbose (bool): Print model info; only honoured when ``RANK == -1``.

        Returns:
            YOLOv10PGTDetectionModel: Model created with ``nc=self.data["nc"]``.
        """
        model = YOLOv10PGTDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model

    def get_validator(self):
        """
        Set ``self.loss_names`` and return a validator over ``self.test_loader``.

        Side effect: ``self.loss_names`` is set to the seven loss component labels
        below. NOTE(review): the ``_om`` / ``_oo`` suffixes presumably denote the
        one-to-many and one-to-one heads; confirm against the loss implementation.

        Returns:
            YOLOv10PGTDetectionValidator: Validator sharing this trainer's
            ``save_dir``, a copy of its ``args``, and its callbacks.
        """
        self.loss_names = (
            "box_om",
            "cls_om",
            "dfl_om",
            "box_oo",
            "cls_oo",
            "dfl_oo",
            "pgt_loss",
        )
        return YOLOv10PGTDetectionValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

    def plot_training_samples(self, batch, ni):
        """
        Save a plot of a training batch (images, classes, boxes, masks).

        Args:
            batch (dict): Batch with keys ``img``, ``batch_idx``, ``cls``,
                ``bboxes``, ``masks`` and ``im_file``. ``masks`` is required, so
                the dataset must produce segmentation-style batches.
            ni (int): Iteration index used in the output file name.

        Output:
            ``self.save_dir / f"train_batch{ni}.jpg"``.
        """
        plot_images(
            batch["img"],
            batch["batch_idx"],
            batch["cls"].squeeze(-1),
            batch["bboxes"],
            masks=batch["masks"],
            paths=batch["im_file"],
            fname=self.save_dir / f"train_batch{ni}.jpg",
            on_plot=self.on_plot,
        )

    def plot_metrics(self):
        """
        Plot the training/validation metrics in ``self.csv`` (saves results.png).

        NOTE(review): ``segment=True`` selects the segmentation column layout,
        while validation runs through a detection validator. Confirm that the
        CSV columns match that layout.
        """
        plot_results(file=self.csv, segment=True, on_plot=self.on_plot)  # save results.png