mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-11-01 15:15:39 +08:00 
			
		
		
		
	Pass callbacks to validator (#7320)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
		
							parent
							
								
									072291bc78
								
							
						
					
					
						commit
						2f9ec8c0b4
					
				| @ -110,7 +110,7 @@ class ClassificationTrainer(BaseTrainer): | ||||
|     def get_validator(self): | ||||
|         """Returns an instance of ClassificationValidator for validation.""" | ||||
|         self.loss_names = ['loss'] | ||||
|         return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir) | ||||
|         return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir, _callbacks=self.callbacks) | ||||
| 
 | ||||
|     def label_loss_items(self, loss_items=None, prefix='train'): | ||||
|         """ | ||||
|  | ||||
| @ -89,7 +89,10 @@ class DetectionTrainer(BaseTrainer): | ||||
|     def get_validator(self): | ||||
|         """Returns a DetectionValidator for YOLO model validation.""" | ||||
|         self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss' | ||||
|         return yolo.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) | ||||
|         return yolo.detect.DetectionValidator(self.test_loader, | ||||
|                                               save_dir=self.save_dir, | ||||
|                                               args=copy(self.args), | ||||
|                                               _callbacks=self.callbacks) | ||||
| 
 | ||||
|     def label_loss_items(self, loss_items=None, prefix='train'): | ||||
|         """ | ||||
|  | ||||
| @ -49,7 +49,10 @@ class PoseTrainer(yolo.detect.DetectionTrainer): | ||||
|     def get_validator(self): | ||||
|         """Returns an instance of the PoseValidator class for validation.""" | ||||
|         self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss' | ||||
|         return yolo.pose.PoseValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) | ||||
|         return yolo.pose.PoseValidator(self.test_loader, | ||||
|                                        save_dir=self.save_dir, | ||||
|                                        args=copy(self.args), | ||||
|                                        _callbacks=self.callbacks) | ||||
| 
 | ||||
|     def plot_training_samples(self, batch, ni): | ||||
|         """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.""" | ||||
|  | ||||
| @ -40,7 +40,10 @@ class SegmentationTrainer(yolo.detect.DetectionTrainer): | ||||
|     def get_validator(self): | ||||
|         """Return an instance of SegmentationValidator for validation of YOLO model.""" | ||||
|         self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss' | ||||
|         return yolo.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) | ||||
|         return yolo.segment.SegmentationValidator(self.test_loader, | ||||
|                                                   save_dir=self.save_dir, | ||||
|                                                   args=copy(self.args), | ||||
|                                                   _callbacks=self.callbacks) | ||||
| 
 | ||||
|     def plot_training_samples(self, batch, ni): | ||||
|         """Creates a plot of training sample images with labels and box coordinates.""" | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Thomas de Lange
						Thomas de Lange