mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-11-04 08:56:11 +08:00 
			
		
		
		
	Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com> Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
		
			
				
	
	
		
			62 lines
		
	
	
		
			2.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			62 lines
		
	
	
		
			2.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Ultralytics YOLO 🚀, AGPL-3.0 license
 | 
						|
 | 
						|
import cv2
 | 
						|
import torch
 | 
						|
from PIL import Image
 | 
						|
 | 
						|
from ultralytics.engine.predictor import BasePredictor
 | 
						|
from ultralytics.engine.results import Results
 | 
						|
from ultralytics.utils import DEFAULT_CFG, ops
 | 
						|
 | 
						|
 | 
						|
class ClassificationPredictor(BasePredictor):
    """
    A class extending the BasePredictor class for prediction based on a classification model.

    Notes:
        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.

    Example:
        ```python
        from ultralytics.utils import ASSETS
        from ultralytics.models.yolo.classify import ClassificationPredictor

        args = dict(model='yolov8n-cls.pt', source=ASSETS)
        predictor = ClassificationPredictor(overrides=args)
        predictor.predict_cli()
        ```
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initializes ClassificationPredictor and sets the task to 'classify'.

        Args:
            cfg: Configuration forwarded to BasePredictor (defaults to DEFAULT_CFG).
            overrides (dict | None): Configuration overrides forwarded to BasePredictor.
            _callbacks: Callbacks forwarded to BasePredictor.
        """
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = "classify"
        # Substring searched for in str(transform) by preprocess() to detect the old ToTensor transform,
        # whose pipeline is fed the raw images instead of RGB PIL images.
        self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor"

    def preprocess(self, img):
        """
        Converts input images to a model-ready tensor on the model's device.

        Args:
            img (torch.Tensor | Iterable): Either an already-batched tensor (used as-is), or an iterable of images
                that are each passed through self.transforms and stacked along a new batch dimension 0.
                On the non-legacy path each image goes through cv2.COLOR_BGR2RGB, so images are presumably
                BGR numpy arrays — confirm with the caller.

        Returns:
            (torch.Tensor): The batch moved to self.model.device, cast to half if self.model.fp16 else float.
        """
        if not isinstance(img, torch.Tensor):
            is_legacy_transform = any(
                self._legacy_transform_name in str(transform) for transform in self.transforms.transforms
            )
            if is_legacy_transform:  # to handle legacy transforms: they receive the raw images unchanged
                img = torch.stack([self.transforms(im) for im in img], dim=0)
            else:
                # Current transforms are fed RGB PIL images built from the (presumably BGR) input arrays.
                img = torch.stack(
                    [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
                )
        # Every non-tensor input was converted by torch.stack above, so img is always a torch.Tensor here.
        img = img.to(self.model.device)
        return img.half() if self.model.fp16 else img.float()  # match the model's precision (fp16 or fp32)

    def postprocess(self, preds, img, orig_imgs):
        """
        Post-processes predictions to return Results objects.

        Args:
            preds: Predictions iterated in batch order; each item is passed to Results as `probs`.
            img: Preprocessed input batch (not used by this method).
            orig_imgs (list | torch.Tensor): Original images; a non-list (tensor) batch is first converted
                with ops.convert_torch2numpy_batch.

        Returns:
            (list[Results]): One Results per item in `preds`, with its path taken from self.batch[0].
        """
        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        names = self.model.names
        paths = self.batch[0]
        # Index by position (rather than zip) so a length mismatch still raises IndexError instead of truncating.
        return [
            Results(orig_imgs[i], path=paths[i], names=names, probs=pred) for i, pred in enumerate(preds)
        ]
 |