mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
Pass callbacks to validator (#7320)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
072291bc78
commit
2f9ec8c0b4
@ -110,7 +110,7 @@ class ClassificationTrainer(BaseTrainer):
|
||||
def get_validator(self):
|
||||
"""Returns an instance of ClassificationValidator for validation."""
|
||||
self.loss_names = ['loss']
|
||||
return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir)
|
||||
return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir, _callbacks=self.callbacks)
|
||||
|
||||
def label_loss_items(self, loss_items=None, prefix='train'):
|
||||
"""
|
||||
|
@ -89,7 +89,10 @@ class DetectionTrainer(BaseTrainer):
|
||||
def get_validator(self):
|
||||
"""Returns a DetectionValidator for YOLO model validation."""
|
||||
self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
|
||||
return yolo.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
|
||||
return yolo.detect.DetectionValidator(self.test_loader,
|
||||
save_dir=self.save_dir,
|
||||
args=copy(self.args),
|
||||
_callbacks=self.callbacks)
|
||||
|
||||
def label_loss_items(self, loss_items=None, prefix='train'):
|
||||
"""
|
||||
|
@ -49,7 +49,10 @@ class PoseTrainer(yolo.detect.DetectionTrainer):
|
||||
def get_validator(self):
|
||||
"""Returns an instance of the PoseValidator class for validation."""
|
||||
self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss'
|
||||
return yolo.pose.PoseValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
|
||||
return yolo.pose.PoseValidator(self.test_loader,
|
||||
save_dir=self.save_dir,
|
||||
args=copy(self.args),
|
||||
_callbacks=self.callbacks)
|
||||
|
||||
def plot_training_samples(self, batch, ni):
|
||||
"""Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
|
||||
|
@ -40,7 +40,10 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer):
|
||||
def get_validator(self):
|
||||
"""Return an instance of SegmentationValidator for validation of YOLO model."""
|
||||
self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss'
|
||||
return yolo.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
|
||||
return yolo.segment.SegmentationValidator(self.test_loader,
|
||||
save_dir=self.save_dir,
|
||||
args=copy(self.args),
|
||||
_callbacks=self.callbacks)
|
||||
|
||||
def plot_training_samples(self, batch, ni):
|
||||
"""Creates a plot of training sample images with labels and box coordinates."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user