mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-11-04 00:45:38 +08:00 
			
		
		
		
	Task augment (#2924)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									f4b34fc30b
								
							
						
					
					
						commit
						c050b2d1a8
					
				@ -54,6 +54,22 @@ class BaseModel(nn.Module):
 | 
				
			|||||||
            visualize (bool): Save the feature maps of the model if True, defaults to False.
 | 
					            visualize (bool): Save the feature maps of the model if True, defaults to False.
 | 
				
			||||||
            augment (bool): Augment image during prediction, defaults to False.
 | 
					            augment (bool): Augment image during prediction, defaults to False.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            (torch.Tensor): The last output of the model.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if augment:
 | 
				
			||||||
 | 
					            return self._predict_augment(x)
 | 
				
			||||||
 | 
					        return self._predict_once(x, profile, visualize)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _predict_once(self, x, profile=False, visualize=False):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Perform a forward pass through the network.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            x (torch.Tensor): The input tensor to the model.
 | 
				
			||||||
 | 
					            profile (bool):  Print the computation time of each layer if True, defaults to False.
 | 
				
			||||||
 | 
					            visualize (bool): Save the feature maps of the model if True, defaults to False.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Returns:
 | 
					        Returns:
 | 
				
			||||||
            (torch.Tensor): The last output of the model.
 | 
					            (torch.Tensor): The last output of the model.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
@ -69,6 +85,13 @@ class BaseModel(nn.Module):
 | 
				
			|||||||
                feature_visualization(x, m.type, m.i, save_dir=visualize)
 | 
					                feature_visualization(x, m.type, m.i, save_dir=visualize)
 | 
				
			||||||
        return x
 | 
					        return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _predict_augment(self, x):
 | 
				
			||||||
 | 
					        """Perform augmentations on input image x and return augmented inference."""
 | 
				
			||||||
 | 
					        LOGGER.warning(
 | 
				
			||||||
 | 
					            f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        return self._predict_once(x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _profile_one_layer(self, m, x, dt):
 | 
					    def _profile_one_layer(self, m, x, dt):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Profile the computation time and FLOPs of a single layer of the model on a given input.
 | 
					        Profile the computation time and FLOPs of a single layer of the model on a given input.
 | 
				
			||||||
@ -225,13 +248,7 @@ class DetectionModel(BaseModel):
 | 
				
			|||||||
            self.info()
 | 
					            self.info()
 | 
				
			||||||
            LOGGER.info('')
 | 
					            LOGGER.info('')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def predict(self, x, augment=False, profile=False, visualize=False):
 | 
					    def _predict_augment(self, x):
 | 
				
			||||||
        """Run forward pass on input image(s) with optional augmentation and profiling."""
 | 
					 | 
				
			||||||
        if augment:
 | 
					 | 
				
			||||||
            return self._forward_augment(x)  # augmented inference, None
 | 
					 | 
				
			||||||
        return super().predict(x, profile=profile, visualize=visualize)  # single-scale inference, train
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def _forward_augment(self, x):
 | 
					 | 
				
			||||||
        """Perform augmentations on input image x and return augmented inference and train outputs."""
 | 
					        """Perform augmentations on input image x and return augmented inference and train outputs."""
 | 
				
			||||||
        img_size = x.shape[-2:]  # height, width
 | 
					        img_size = x.shape[-2:]  # height, width
 | 
				
			||||||
        s = [1, 0.83, 0.67]  # scales
 | 
					        s = [1, 0.83, 0.67]  # scales
 | 
				
			||||||
@ -279,13 +296,16 @@ class SegmentationModel(DetectionModel):
 | 
				
			|||||||
        """Initialize YOLOv8 segmentation model with given config and parameters."""
 | 
					        """Initialize YOLOv8 segmentation model with given config and parameters."""
 | 
				
			||||||
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 | 
					        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _forward_augment(self, x):
 | 
					 | 
				
			||||||
        """Undocumented function."""
 | 
					 | 
				
			||||||
        raise NotImplementedError(emojis('WARNING ⚠️ SegmentationModel has not supported augment inference yet!'))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def init_criterion(self):
 | 
					    def init_criterion(self):
 | 
				
			||||||
        return v8SegmentationLoss(self)
 | 
					        return v8SegmentationLoss(self)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _predict_augment(self, x):
 | 
				
			||||||
 | 
					        """Perform augmentations on input image x and return augmented inference."""
 | 
				
			||||||
 | 
					        LOGGER.warning(
 | 
				
			||||||
 | 
					            f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        return self._predict_once(x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class PoseModel(DetectionModel):
 | 
					class PoseModel(DetectionModel):
 | 
				
			||||||
    """YOLOv8 pose model."""
 | 
					    """YOLOv8 pose model."""
 | 
				
			||||||
@ -302,9 +322,12 @@ class PoseModel(DetectionModel):
 | 
				
			|||||||
    def init_criterion(self):
 | 
					    def init_criterion(self):
 | 
				
			||||||
        return v8PoseLoss(self)
 | 
					        return v8PoseLoss(self)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _forward_augment(self, x):
 | 
					    def _predict_augment(self, x):
 | 
				
			||||||
        """Undocumented function."""
 | 
					        """Perform augmentations on input image x and return augmented inference."""
 | 
				
			||||||
        raise NotImplementedError(emojis('WARNING ⚠️ PoseModel has not supported augment inference yet!'))
 | 
					        LOGGER.warning(
 | 
				
			||||||
 | 
					            f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        return self._predict_once(x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ClassificationModel(BaseModel):
 | 
					class ClassificationModel(BaseModel):
 | 
				
			||||||
@ -448,10 +471,6 @@ class RTDETRDetectionModel(DetectionModel):
 | 
				
			|||||||
        x = head([y[j] for j in head.f], batch)  # head inference
 | 
					        x = head([y[j] for j in head.f], batch)  # head inference
 | 
				
			||||||
        return x
 | 
					        return x
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _forward_augment(self, x):
 | 
					 | 
				
			||||||
        """Undocumented function."""
 | 
					 | 
				
			||||||
        raise NotImplementedError(emojis('WARNING ⚠️ RTDETRModel has not supported augment inference yet!'))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Ensemble(nn.ModuleList):
 | 
					class Ensemble(nn.ModuleList):
 | 
				
			||||||
    """Ensemble of models."""
 | 
					    """Ensemble of models."""
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user