mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-11-04 08:56:11 +08:00 
			
		
		
		
	ultralytics 8.0.159 add Classify training resume feature (#4482)
				
					
				
			This commit is contained in:
		
							parent
							
								
									b2f279ffdd
								
							
						
					
					
						commit
						c0a9660310
					
				@ -1,6 +1,6 @@
 | 
				
			|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
 | 
					# Ultralytics YOLO 🚀, AGPL-3.0 license
 | 
				
			||||||
 | 
					
 | 
				
			||||||
__version__ = '8.0.158'
 | 
					__version__ = '8.0.159'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ultralytics.hub import start
 | 
					from ultralytics.hub import start
 | 
				
			||||||
from ultralytics.models import RTDETR, SAM, YOLO
 | 
					from ultralytics.models import RTDETR, SAM, YOLO
 | 
				
			||||||
 | 
				
			|||||||
@ -419,7 +419,7 @@ def entrypoint(debug=''):
 | 
				
			|||||||
        overrides['source'] = DEFAULT_CFG.source or ASSETS
 | 
					        overrides['source'] = DEFAULT_CFG.source or ASSETS
 | 
				
			||||||
        LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
 | 
					        LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
 | 
				
			||||||
    elif mode in ('train', 'val'):
 | 
					    elif mode in ('train', 'val'):
 | 
				
			||||||
        if 'data' not in overrides:
 | 
					        if 'data' not in overrides and 'resume' not in overrides:
 | 
				
			||||||
            overrides['data'] = TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
 | 
					            overrides['data'] = TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
 | 
				
			||||||
            LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
 | 
					            LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
 | 
				
			||||||
    elif mode == 'export':
 | 
					    elif mode == 'export':
 | 
				
			||||||
 | 
				
			|||||||
@ -62,10 +62,10 @@ class ClassificationTrainer(BaseTrainer):
 | 
				
			|||||||
        if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
 | 
					        if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        model = str(self.model)
 | 
					        model, ckpt = str(self.model), None
 | 
				
			||||||
        # Load a YOLO model locally, from torchvision, or from Ultralytics assets
 | 
					        # Load a YOLO model locally, from torchvision, or from Ultralytics assets
 | 
				
			||||||
        if model.endswith('.pt'):
 | 
					        if model.endswith('.pt'):
 | 
				
			||||||
            self.model, _ = attempt_load_one_weight(model, device='cpu')
 | 
					            self.model, ckpt = attempt_load_one_weight(model, device='cpu')
 | 
				
			||||||
            for p in self.model.parameters():
 | 
					            for p in self.model.parameters():
 | 
				
			||||||
                p.requires_grad = True  # for training
 | 
					                p.requires_grad = True  # for training
 | 
				
			||||||
        elif model.split('.')[-1] in ('yaml', 'yml'):
 | 
					        elif model.split('.')[-1] in ('yaml', 'yml'):
 | 
				
			||||||
@ -76,7 +76,7 @@ class ClassificationTrainer(BaseTrainer):
 | 
				
			|||||||
            FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
 | 
					            FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
 | 
				
			||||||
        ClassificationModel.reshape_outputs(self.model, self.data['nc'])
 | 
					        ClassificationModel.reshape_outputs(self.model, self.data['nc'])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return  # do not return ckpt. Classification doesn't support resume
 | 
					        return ckpt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def build_dataset(self, img_path, mode='train', batch=None):
 | 
					    def build_dataset(self, img_path, mode='train', batch=None):
 | 
				
			||||||
        return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train')
 | 
					        return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train')
 | 
				
			||||||
@ -122,10 +122,6 @@ class ClassificationTrainer(BaseTrainer):
 | 
				
			|||||||
        loss_items = [round(float(loss_items), 5)]
 | 
					        loss_items = [round(float(loss_items), 5)]
 | 
				
			||||||
        return dict(zip(keys, loss_items))
 | 
					        return dict(zip(keys, loss_items))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def resume_training(self, ckpt):
 | 
					 | 
				
			||||||
        """Resumes training from a given checkpoint."""
 | 
					 | 
				
			||||||
        pass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def plot_metrics(self):
 | 
					    def plot_metrics(self):
 | 
				
			||||||
        """Plots metrics from a CSV file."""
 | 
					        """Plots metrics from a CSV file."""
 | 
				
			||||||
        plot_results(file=self.csv, classify=True, on_plot=self.on_plot)  # save results.png
 | 
					        plot_results(file=self.csv, classify=True, on_plot=self.on_plot)  # save results.png
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user