mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-11-04 08:56:11 +08:00 
			
		
		
		
	Allow setting model attributes before training (#45)
This commit is contained in:
		
							parent
							
								
									832ea56eb4
								
							
						
					
					
						commit
						db1031a1a9
					
				@ -133,6 +133,7 @@ class BaseTrainer:
 | 
				
			|||||||
        """
 | 
					        """
 | 
				
			||||||
        Builds dataloaders and optimizer on correct rank process
 | 
					        Builds dataloaders and optimizer on correct rank process
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
 | 
					        self.set_model_attributes()
 | 
				
			||||||
        self.optimizer = build_optimizer(model=self.model,
 | 
					        self.optimizer = build_optimizer(model=self.model,
 | 
				
			||||||
                                         name=self.args.optimizer,
 | 
					                                         name=self.args.optimizer,
 | 
				
			||||||
                                         lr=self.args.lr0,
 | 
					                                         lr=self.args.lr0,
 | 
				
			||||||
@ -146,19 +147,6 @@ class BaseTrainer:
 | 
				
			|||||||
            print("created testloader :", rank)
 | 
					            print("created testloader :", rank)
 | 
				
			||||||
            self.console.info(self.progress_string())
 | 
					            self.console.info(self.progress_string())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _set_model_attributes(self):
 | 
					 | 
				
			||||||
        # TODO: fix and use after self.data_dict is available
 | 
					 | 
				
			||||||
        '''
 | 
					 | 
				
			||||||
        head = utils.torch_utils.de_parallel(self.model).model[-1]
 | 
					 | 
				
			||||||
        self.args.box *= 3 / head.nl  # scale to layers
 | 
					 | 
				
			||||||
        self.args.cls *= head.nc / 80 * 3 / head.nl  # scale to classes and layers
 | 
					 | 
				
			||||||
        self.args.obj *= (self.args.img_size / 640) ** 2 * 3 / nl  # scale to image size and layers
 | 
					 | 
				
			||||||
        model.nc = nc  # attach number of classes to model
 | 
					 | 
				
			||||||
        model.hyp = hyp  # attach hyperparameters to model
 | 
					 | 
				
			||||||
        model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
 | 
					 | 
				
			||||||
        model.names = names
 | 
					 | 
				
			||||||
        '''
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def _do_train(self, rank, world_size):
 | 
					    def _do_train(self, rank, world_size):
 | 
				
			||||||
        if world_size > 1:
 | 
					        if world_size > 1:
 | 
				
			||||||
            self._setup_ddp(rank, world_size)
 | 
					            self._setup_ddp(rank, world_size)
 | 
				
			||||||
@ -302,6 +290,12 @@ class BaseTrainer:
 | 
				
			|||||||
        if not self.best_fitness or self.best_fitness < self.fitness:
 | 
					        if not self.best_fitness or self.best_fitness < self.fitness:
 | 
				
			||||||
            self.best_fitness = self.fitness
 | 
					            self.best_fitness = self.fitness
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def set_model_attributes(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        To set or update model parameters before training.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def build_targets(self, preds, targets):
 | 
					    def build_targets(self, preds, targets):
 | 
				
			||||||
        pass
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -54,6 +54,16 @@ class SegmentationTrainer(BaseTrainer):
 | 
				
			|||||||
            model.load(weights)
 | 
					            model.load(weights)
 | 
				
			||||||
        return model
 | 
					        return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def set_model_attributes(self):
 | 
				
			||||||
 | 
					        nl = de_parallel(self.model).model[-1].nl  # number of detection layers (to scale hyps)
 | 
				
			||||||
 | 
					        self.args.box *= 3 / nl  # scale to layers
 | 
				
			||||||
 | 
					        self.args.cls *= self.data["nc"] / 80 * 3 / nl  # scale to classes and layers
 | 
				
			||||||
 | 
					        self.args.obj *= (self.args.img_size / 640) ** 2 * 3 / nl  # scale to image size and layers
 | 
				
			||||||
 | 
					        self.model.nc = self.data["nc"]  # attach number of classes to model
 | 
				
			||||||
 | 
					        self.model.args = self.args  # attach hyperparameters to model
 | 
				
			||||||
 | 
					        # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
 | 
				
			||||||
 | 
					        self.model.names = self.data["names"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_validator(self):
 | 
					    def get_validator(self):
 | 
				
			||||||
        return v8.segment.SegmentationValidator(self.test_loader, self.device, logger=self.console)
 | 
					        return v8.segment.SegmentationValidator(self.test_loader, self.device, logger=self.console)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user