mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 05:24:22 +08:00
ultralytics 8.0.229
add model.embed()
method (#7098)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
38eaf5e29f
commit
5b3e20379f
@ -125,5 +125,6 @@ Monitoring workouts through pose estimation with [Ultralytics YOLOv8](https://gi
|
||||
| `visualize` | `bool` | `False` | visualize model features |
|
||||
| `augment` | `bool` | `False` | apply image augmentation to prediction sources |
|
||||
| `agnostic_nms` | `bool` | `False` | class-agnostic NMS |
|
||||
| `classes` | `list[int]` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] |
|
||||
| `retina_masks` | `bool` | `False` | use high-resolution segmentation masks |
|
||||
| `classes` | `None or list` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] |
|
||||
| `embed` | `list[int]` | `None` | return feature vectors/embeddings from given layers |
|
||||
|
@ -355,8 +355,9 @@ Inference arguments:
|
||||
| `visualize` | `bool` | `False` | visualize model features |
|
||||
| `augment` | `bool` | `False` | apply image augmentation to prediction sources |
|
||||
| `agnostic_nms` | `bool` | `False` | class-agnostic NMS |
|
||||
| `classes` | `list[int]` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] |
|
||||
| `retina_masks` | `bool` | `False` | use high-resolution segmentation masks |
|
||||
| `classes` | `None or list` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] |
|
||||
| `embed` | `list[int]` | `None` | return feature vectors/embeddings from given layers |
|
||||
|
||||
Visualization arguments:
|
||||
|
||||
|
@ -18,3 +18,7 @@ keywords: Ultralytics, AutoBackend, check_class_names, YOLO, YOLO models, optimi
|
||||
## ::: ultralytics.nn.autobackend.check_class_names
|
||||
|
||||
<br><br>
|
||||
|
||||
## ::: ultralytics.nn.autobackend.default_class_names
|
||||
|
||||
<br><br>
|
||||
|
@ -156,8 +156,9 @@ Inference arguments:
|
||||
| `visualize` | `bool` | `False` | visualize model features |
|
||||
| `augment` | `bool` | `False` | apply image augmentation to prediction sources |
|
||||
| `agnostic_nms` | `bool` | `False` | class-agnostic NMS |
|
||||
| `classes` | `list[int]` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] |
|
||||
| `retina_masks` | `bool` | `False` | use high-resolution segmentation masks |
|
||||
| `classes` | `None or list` | `None` | filter results by class, i.e. classes=0, or classes=[0,2,3] |
|
||||
| `embed` | `list[int]` | `None` | return feature vectors/embeddings from given layers |
|
||||
|
||||
Visualization arguments:
|
||||
|
||||
|
@ -511,3 +511,13 @@ def test_model_tune():
|
||||
"""Tune YOLO model for performance."""
|
||||
YOLO('yolov8n-pose.pt').tune(data='coco8-pose.yaml', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu')
|
||||
YOLO('yolov8n-cls.pt').tune(data='imagenet10', plots=False, imgsz=32, epochs=1, iterations=2, device='cpu')
|
||||
|
||||
|
||||
def test_model_embeddings():
|
||||
"""Test YOLO model embeddings."""
|
||||
model_detect = YOLO(MODEL)
|
||||
model_segment = YOLO(WEIGHTS_DIR / 'yolov8n-seg.pt')
|
||||
|
||||
for batch in [SOURCE], [SOURCE, SOURCE]: # test batch size 1 and 2
|
||||
assert len(model_detect.embed(source=batch, imgsz=32)) == len(batch)
|
||||
assert len(model_segment.embed(source=batch, imgsz=32)) == len(batch)
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = '8.0.228'
|
||||
__version__ = '8.0.229'
|
||||
|
||||
from ultralytics.models import RTDETR, SAM, YOLO
|
||||
from ultralytics.models.fastsam import FastSAM
|
||||
|
@ -61,6 +61,7 @@ augment: False # (bool) apply image augmentation to prediction sources
|
||||
agnostic_nms: False # (bool) class-agnostic NMS
|
||||
classes: # (int | list[int], optional) filter results by class, i.e. classes=0, or classes=[0,2,3]
|
||||
retina_masks: False # (bool) use high-resolution segmentation masks
|
||||
embed: # (list[int], optional) return feature vectors/embeddings from given layers
|
||||
|
||||
# Visualize settings ---------------------------------------------------------------------------------------------------
|
||||
show: False # (bool) show predicted images and videos if environment allows
|
||||
|
@ -94,7 +94,7 @@ class Model(nn.Module):
|
||||
self._load(model, task)
|
||||
|
||||
def __call__(self, source=None, stream=False, **kwargs):
|
||||
"""Calls the 'predict' function with given arguments to perform object detection."""
|
||||
"""Calls the predict() method with given arguments to perform object detection."""
|
||||
return self.predict(source, stream, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
@ -201,6 +201,24 @@ class Model(nn.Module):
|
||||
self._check_is_pytorch_model()
|
||||
self.model.fuse()
|
||||
|
||||
def embed(self, source=None, stream=False, **kwargs):
|
||||
"""
|
||||
Calls the predict() method and returns image embeddings.
|
||||
|
||||
Args:
|
||||
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
|
||||
Accepts all source types accepted by the YOLO model.
|
||||
stream (bool): Whether to stream the predictions or not. Defaults to False.
|
||||
**kwargs : Additional keyword arguments passed to the predictor.
|
||||
Check the 'configuration' section in the documentation for all available options.
|
||||
|
||||
Returns:
|
||||
(List[torch.Tensor]): A list of image embeddings.
|
||||
"""
|
||||
if not kwargs.get('embed'):
|
||||
kwargs['embed'] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed
|
||||
return self.predict(source, stream, **kwargs)
|
||||
|
||||
def predict(self, source=None, stream=False, predictor=None, **kwargs):
|
||||
"""
|
||||
Perform prediction using the YOLO model.
|
||||
|
@ -134,7 +134,7 @@ class BasePredictor:
|
||||
"""Runs inference on a given image using the specified model and arguments."""
|
||||
visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
|
||||
mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
|
||||
return self.model(im, augment=self.args.augment, visualize=visualize)
|
||||
return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
|
||||
|
||||
def pre_transform(self, im):
|
||||
"""
|
||||
@ -263,6 +263,9 @@ class BasePredictor:
|
||||
# Inference
|
||||
with profilers[1]:
|
||||
preds = self.inference(im, *args, **kwargs)
|
||||
if self.args.embed:
|
||||
yield from [preds] if isinstance(preds, torch.Tensor) else preds # yield embedding tensors
|
||||
continue
|
||||
|
||||
# Postprocess
|
||||
with profilers[2]:
|
||||
|
@ -333,7 +333,7 @@ class AutoBackend(nn.Module):
|
||||
|
||||
self.__dict__.update(locals()) # assign all variables to self
|
||||
|
||||
def forward(self, im, augment=False, visualize=False):
|
||||
def forward(self, im, augment=False, visualize=False, embed=None):
|
||||
"""
|
||||
Runs inference on the YOLOv8 MultiBackend model.
|
||||
|
||||
@ -341,6 +341,7 @@ class AutoBackend(nn.Module):
|
||||
im (torch.Tensor): The image tensor to perform inference on.
|
||||
augment (bool): whether to perform data augmentation during inference, defaults to False
|
||||
visualize (bool): whether to visualize the output predictions, defaults to False
|
||||
embed (list, optional): A list of feature vectors/embeddings to return.
|
||||
|
||||
Returns:
|
||||
(tuple): Tuple containing the raw output tensor, and processed output for visualization (if visualize=True)
|
||||
@ -352,7 +353,7 @@ class AutoBackend(nn.Module):
|
||||
im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3)
|
||||
|
||||
if self.pt or self.nn_module: # PyTorch
|
||||
y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
|
||||
y = self.model(im, augment=augment, visualize=visualize, embed=embed)
|
||||
elif self.jit: # TorchScript
|
||||
y = self.model(im)
|
||||
elif self.dnn: # ONNX OpenCV DNN
|
||||
|
@ -41,7 +41,7 @@ class BaseModel(nn.Module):
|
||||
return self.loss(x, *args, **kwargs)
|
||||
return self.predict(x, *args, **kwargs)
|
||||
|
||||
def predict(self, x, profile=False, visualize=False, augment=False):
|
||||
def predict(self, x, profile=False, visualize=False, augment=False, embed=None):
|
||||
"""
|
||||
Perform a forward pass through the network.
|
||||
|
||||
@ -50,15 +50,16 @@ class BaseModel(nn.Module):
|
||||
profile (bool): Print the computation time of each layer if True, defaults to False.
|
||||
visualize (bool): Save the feature maps of the model if True, defaults to False.
|
||||
augment (bool): Augment image during prediction, defaults to False.
|
||||
embed (list, optional): A list of feature vectors/embeddings to return.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): The last output of the model.
|
||||
"""
|
||||
if augment:
|
||||
return self._predict_augment(x)
|
||||
return self._predict_once(x, profile, visualize)
|
||||
return self._predict_once(x, profile, visualize, embed)
|
||||
|
||||
def _predict_once(self, x, profile=False, visualize=False):
|
||||
def _predict_once(self, x, profile=False, visualize=False, embed=None):
|
||||
"""
|
||||
Perform a forward pass through the network.
|
||||
|
||||
@ -66,11 +67,12 @@ class BaseModel(nn.Module):
|
||||
x (torch.Tensor): The input tensor to the model.
|
||||
profile (bool): Print the computation time of each layer if True, defaults to False.
|
||||
visualize (bool): Save the feature maps of the model if True, defaults to False.
|
||||
embed (list, optional): A list of feature vectors/embeddings to return.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): The last output of the model.
|
||||
"""
|
||||
y, dt = [], [] # outputs
|
||||
y, dt, embeddings = [], [], [] # outputs
|
||||
for m in self.model:
|
||||
if m.f != -1: # if not from previous layer
|
||||
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
|
||||
@ -80,6 +82,10 @@ class BaseModel(nn.Module):
|
||||
y.append(x if m.i in self.save else None) # save output
|
||||
if visualize:
|
||||
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
||||
if embed and m.i in embed:
|
||||
embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
|
||||
if m.i == max(embed):
|
||||
return torch.unbind(torch.cat(embeddings, 1), dim=0)
|
||||
return x
|
||||
|
||||
def _predict_augment(self, x):
|
||||
@ -454,7 +460,7 @@ class RTDETRDetectionModel(DetectionModel):
|
||||
return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']],
|
||||
device=img.device)
|
||||
|
||||
def predict(self, x, profile=False, visualize=False, batch=None, augment=False):
|
||||
def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
|
||||
"""
|
||||
Perform a forward pass through the model.
|
||||
|
||||
@ -464,11 +470,12 @@ class RTDETRDetectionModel(DetectionModel):
|
||||
visualize (bool, optional): If True, save feature maps for visualization. Defaults to False.
|
||||
batch (dict, optional): Ground truth data for evaluation. Defaults to None.
|
||||
augment (bool, optional): If True, perform data augmentation during inference. Defaults to False.
|
||||
embed (list, optional): A list of feature vectors/embeddings to return.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Model's output tensor.
|
||||
"""
|
||||
y, dt = [], [] # outputs
|
||||
y, dt, embeddings = [], [], [] # outputs
|
||||
for m in self.model[:-1]: # except the head part
|
||||
if m.f != -1: # if not from previous layer
|
||||
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
|
||||
@ -478,6 +485,10 @@ class RTDETRDetectionModel(DetectionModel):
|
||||
y.append(x if m.i in self.save else None) # save output
|
||||
if visualize:
|
||||
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
||||
if embed and m.i in embed:
|
||||
embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
|
||||
if m.i == max(embed):
|
||||
return torch.unbind(torch.cat(embeddings, 1), dim=0)
|
||||
head = self.model[-1]
|
||||
x = head([y[j] for j in head.f], batch) # head inference
|
||||
return x
|
||||
|
Loading…
x
Reference in New Issue
Block a user