From 2f9ec8c0b4e27819e14e7794cb7287fd0fdba24a Mon Sep 17 00:00:00 2001 From: Thomas de Lange <33953384+ThomasDeLange@users.noreply.github.com> Date: Fri, 5 Jan 2024 09:08:17 +0000 Subject: [PATCH] Pass callbacks to validator (#7320) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher --- ultralytics/models/yolo/classify/train.py | 2 +- ultralytics/models/yolo/detect/train.py | 5 ++++- ultralytics/models/yolo/pose/train.py | 5 ++++- ultralytics/models/yolo/segment/train.py | 5 ++++- 4 files changed, 13 insertions(+), 4 deletions(-) diff --git a/ultralytics/models/yolo/classify/train.py b/ultralytics/models/yolo/classify/train.py index c59f2853..3c23b2a0 100644 --- a/ultralytics/models/yolo/classify/train.py +++ b/ultralytics/models/yolo/classify/train.py @@ -110,7 +110,7 @@ class ClassificationTrainer(BaseTrainer): def get_validator(self): """Returns an instance of ClassificationValidator for validation.""" self.loss_names = ['loss'] - return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir) + return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir, _callbacks=self.callbacks) def label_loss_items(self, loss_items=None, prefix='train'): """ diff --git a/ultralytics/models/yolo/detect/train.py b/ultralytics/models/yolo/detect/train.py index 5cfaa9f4..fc656984 100644 --- a/ultralytics/models/yolo/detect/train.py +++ b/ultralytics/models/yolo/detect/train.py @@ -89,7 +89,10 @@ class DetectionTrainer(BaseTrainer): def get_validator(self): """Returns a DetectionValidator for YOLO model validation.""" self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss' - return yolo.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) + return yolo.detect.DetectionValidator(self.test_loader, + save_dir=self.save_dir, + args=copy(self.args), + _callbacks=self.callbacks) def label_loss_items(self, loss_items=None, prefix='train'): """ diff --git a/ultralytics/models/yolo/pose/train.py b/ultralytics/models/yolo/pose/train.py index 2d4f4e0d..c9ccf52e 100644 --- a/ultralytics/models/yolo/pose/train.py +++ b/ultralytics/models/yolo/pose/train.py @@ -49,7 +49,10 @@ class PoseTrainer(yolo.detect.DetectionTrainer): def get_validator(self): """Returns an instance of the PoseValidator class for validation.""" self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss' - return yolo.pose.PoseValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) + return yolo.pose.PoseValidator(self.test_loader, + save_dir=self.save_dir, + args=copy(self.args), + _callbacks=self.callbacks) def plot_training_samples(self, batch, ni): """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.""" diff --git a/ultralytics/models/yolo/segment/train.py b/ultralytics/models/yolo/segment/train.py index b290192c..949f3cd6 100644 --- a/ultralytics/models/yolo/segment/train.py +++ b/ultralytics/models/yolo/segment/train.py @@ -40,7 +40,10 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer): def get_validator(self): """Return an instance of SegmentationValidator for validation of YOLO model.""" self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss' - return yolo.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) + return yolo.segment.SegmentationValidator(self.test_loader, + save_dir=self.save_dir, + args=copy(self.args), + _callbacks=self.callbacks) def plot_training_samples(self, batch, ni): """Creates a plot of training sample images with labels and box coordinates."""