mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Task augment (#2924)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
f4b34fc30b
commit
c050b2d1a8
@ -54,6 +54,22 @@ class BaseModel(nn.Module):
|
|||||||
visualize (bool): Save the feature maps of the model if True, defaults to False.
|
visualize (bool): Save the feature maps of the model if True, defaults to False.
|
||||||
augment (bool): Augment image during prediction, defaults to False.
|
augment (bool): Augment image during prediction, defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(torch.Tensor): The last output of the model.
|
||||||
|
"""
|
||||||
|
if augment:
|
||||||
|
return self._predict_augment(x)
|
||||||
|
return self._predict_once(x, profile, visualize)
|
||||||
|
|
||||||
|
def _predict_once(self, x, profile=False, visualize=False):
|
||||||
|
"""
|
||||||
|
Perform a forward pass through the network.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): The input tensor to the model.
|
||||||
|
profile (bool): Print the computation time of each layer if True, defaults to False.
|
||||||
|
visualize (bool): Save the feature maps of the model if True, defaults to False.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(torch.Tensor): The last output of the model.
|
(torch.Tensor): The last output of the model.
|
||||||
"""
|
"""
|
||||||
@ -69,6 +85,13 @@ class BaseModel(nn.Module):
|
|||||||
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def _predict_augment(self, x):
|
||||||
|
"""Perform augmentations on input image x and return augmented inference."""
|
||||||
|
LOGGER.warning(
|
||||||
|
f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
|
||||||
|
)
|
||||||
|
return self._predict_once(x)
|
||||||
|
|
||||||
def _profile_one_layer(self, m, x, dt):
|
def _profile_one_layer(self, m, x, dt):
|
||||||
"""
|
"""
|
||||||
Profile the computation time and FLOPs of a single layer of the model on a given input.
|
Profile the computation time and FLOPs of a single layer of the model on a given input.
|
||||||
@ -225,13 +248,7 @@ class DetectionModel(BaseModel):
|
|||||||
self.info()
|
self.info()
|
||||||
LOGGER.info('')
|
LOGGER.info('')
|
||||||
|
|
||||||
def predict(self, x, augment=False, profile=False, visualize=False):
|
def _predict_augment(self, x):
|
||||||
"""Run forward pass on input image(s) with optional augmentation and profiling."""
|
|
||||||
if augment:
|
|
||||||
return self._forward_augment(x) # augmented inference, None
|
|
||||||
return super().predict(x, profile=profile, visualize=visualize) # single-scale inference, train
|
|
||||||
|
|
||||||
def _forward_augment(self, x):
|
|
||||||
"""Perform augmentations on input image x and return augmented inference and train outputs."""
|
"""Perform augmentations on input image x and return augmented inference and train outputs."""
|
||||||
img_size = x.shape[-2:] # height, width
|
img_size = x.shape[-2:] # height, width
|
||||||
s = [1, 0.83, 0.67] # scales
|
s = [1, 0.83, 0.67] # scales
|
||||||
@ -279,13 +296,16 @@ class SegmentationModel(DetectionModel):
|
|||||||
"""Initialize YOLOv8 segmentation model with given config and parameters."""
|
"""Initialize YOLOv8 segmentation model with given config and parameters."""
|
||||||
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
||||||
|
|
||||||
def _forward_augment(self, x):
|
|
||||||
"""Undocumented function."""
|
|
||||||
raise NotImplementedError(emojis('WARNING ⚠️ SegmentationModel has not supported augment inference yet!'))
|
|
||||||
|
|
||||||
def init_criterion(self):
|
def init_criterion(self):
|
||||||
return v8SegmentationLoss(self)
|
return v8SegmentationLoss(self)
|
||||||
|
|
||||||
|
def _predict_augment(self, x):
|
||||||
|
"""Perform augmentations on input image x and return augmented inference."""
|
||||||
|
LOGGER.warning(
|
||||||
|
f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
|
||||||
|
)
|
||||||
|
return self._predict_once(x)
|
||||||
|
|
||||||
|
|
||||||
class PoseModel(DetectionModel):
|
class PoseModel(DetectionModel):
|
||||||
"""YOLOv8 pose model."""
|
"""YOLOv8 pose model."""
|
||||||
@ -302,9 +322,12 @@ class PoseModel(DetectionModel):
|
|||||||
def init_criterion(self):
|
def init_criterion(self):
|
||||||
return v8PoseLoss(self)
|
return v8PoseLoss(self)
|
||||||
|
|
||||||
def _forward_augment(self, x):
|
def _predict_augment(self, x):
|
||||||
"""Undocumented function."""
|
"""Perform augmentations on input image x and return augmented inference."""
|
||||||
raise NotImplementedError(emojis('WARNING ⚠️ PoseModel has not supported augment inference yet!'))
|
LOGGER.warning(
|
||||||
|
f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
|
||||||
|
)
|
||||||
|
return self._predict_once(x)
|
||||||
|
|
||||||
|
|
||||||
class ClassificationModel(BaseModel):
|
class ClassificationModel(BaseModel):
|
||||||
@ -448,10 +471,6 @@ class RTDETRDetectionModel(DetectionModel):
|
|||||||
x = head([y[j] for j in head.f], batch) # head inference
|
x = head([y[j] for j in head.f], batch) # head inference
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def _forward_augment(self, x):
|
|
||||||
"""Undocumented function."""
|
|
||||||
raise NotImplementedError(emojis('WARNING ⚠️ RTDETRModel has not supported augment inference yet!'))
|
|
||||||
|
|
||||||
|
|
||||||
class Ensemble(nn.ModuleList):
|
class Ensemble(nn.ModuleList):
|
||||||
"""Ensemble of models."""
|
"""Ensemble of models."""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user