mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
Task augment (#2924)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
f4b34fc30b
commit
c050b2d1a8
@ -54,6 +54,22 @@ class BaseModel(nn.Module):
|
||||
visualize (bool): Save the feature maps of the model if True, defaults to False.
|
||||
augment (bool): Augment image during prediction, defaults to False.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): The last output of the model.
|
||||
"""
|
||||
if augment:
|
||||
return self._predict_augment(x)
|
||||
return self._predict_once(x, profile, visualize)
|
||||
|
||||
def _predict_once(self, x, profile=False, visualize=False):
|
||||
"""
|
||||
Perform a forward pass through the network.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor to the model.
|
||||
profile (bool): Print the computation time of each layer if True, defaults to False.
|
||||
visualize (bool): Save the feature maps of the model if True, defaults to False.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): The last output of the model.
|
||||
"""
|
||||
@ -69,6 +85,13 @@ class BaseModel(nn.Module):
|
||||
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
||||
return x
|
||||
|
||||
def _predict_augment(self, x):
|
||||
"""Perform augmentations on input image x and return augmented inference."""
|
||||
LOGGER.warning(
|
||||
f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
|
||||
)
|
||||
return self._predict_once(x)
|
||||
|
||||
def _profile_one_layer(self, m, x, dt):
|
||||
"""
|
||||
Profile the computation time and FLOPs of a single layer of the model on a given input.
|
||||
@ -225,13 +248,7 @@ class DetectionModel(BaseModel):
|
||||
self.info()
|
||||
LOGGER.info('')
|
||||
|
||||
def predict(self, x, augment=False, profile=False, visualize=False):
|
||||
"""Run forward pass on input image(s) with optional augmentation and profiling."""
|
||||
if augment:
|
||||
return self._forward_augment(x) # augmented inference, None
|
||||
return super().predict(x, profile=profile, visualize=visualize) # single-scale inference, train
|
||||
|
||||
def _forward_augment(self, x):
|
||||
def _predict_augment(self, x):
|
||||
"""Perform augmentations on input image x and return augmented inference and train outputs."""
|
||||
img_size = x.shape[-2:] # height, width
|
||||
s = [1, 0.83, 0.67] # scales
|
||||
@ -279,13 +296,16 @@ class SegmentationModel(DetectionModel):
|
||||
"""Initialize YOLOv8 segmentation model with given config and parameters."""
|
||||
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
||||
|
||||
def _forward_augment(self, x):
|
||||
"""Undocumented function."""
|
||||
raise NotImplementedError(emojis('WARNING ⚠️ SegmentationModel has not supported augment inference yet!'))
|
||||
|
||||
def init_criterion(self):
|
||||
return v8SegmentationLoss(self)
|
||||
|
||||
def _predict_augment(self, x):
|
||||
"""Perform augmentations on input image x and return augmented inference."""
|
||||
LOGGER.warning(
|
||||
f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
|
||||
)
|
||||
return self._predict_once(x)
|
||||
|
||||
|
||||
class PoseModel(DetectionModel):
|
||||
"""YOLOv8 pose model."""
|
||||
@ -302,9 +322,12 @@ class PoseModel(DetectionModel):
|
||||
def init_criterion(self):
|
||||
return v8PoseLoss(self)
|
||||
|
||||
def _forward_augment(self, x):
|
||||
"""Undocumented function."""
|
||||
raise NotImplementedError(emojis('WARNING ⚠️ PoseModel has not supported augment inference yet!'))
|
||||
def _predict_augment(self, x):
|
||||
"""Perform augmentations on input image x and return augmented inference."""
|
||||
LOGGER.warning(
|
||||
f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
|
||||
)
|
||||
return self._predict_once(x)
|
||||
|
||||
|
||||
class ClassificationModel(BaseModel):
|
||||
@ -448,10 +471,6 @@ class RTDETRDetectionModel(DetectionModel):
|
||||
x = head([y[j] for j in head.f], batch) # head inference
|
||||
return x
|
||||
|
||||
def _forward_augment(self, x):
|
||||
"""Undocumented function."""
|
||||
raise NotImplementedError(emojis('WARNING ⚠️ RTDETRModel has not supported augment inference yet!'))
|
||||
|
||||
|
||||
class Ensemble(nn.ModuleList):
|
||||
"""Ensemble of models."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user