From 379f31904a2c2e0c88849df0cc56fb6dd0936e0f Mon Sep 17 00:00:00 2001
From: Glenn Jocher <glenn.jocher@ultralytics.com>
Date: Sun, 30 Jul 2023 22:24:28 +0200
Subject: [PATCH] Fix per-layer FLOPs profiling (#4048)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 ultralytics/nn/tasks.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index 95c4d342..afd327df 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -87,9 +87,8 @@ class BaseModel(nn.Module):
 
     def _predict_augment(self, x):
         """Perform augmentations on input image x and return augmented inference."""
-        LOGGER.warning(
-            f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
-        )
+        LOGGER.warning(f'WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. '
+                       f'Reverting to single-scale inference instead.')
         return self._predict_once(x)
 
     def _profile_one_layer(self, m, x, dt):
@@ -105,15 +104,15 @@ class BaseModel(nn.Module):
         Returns:
             None
         """
-        c = m == self.model[-1]  # is final layer, copy input as inplace fix
-        o = thop.profile(m, inputs=[x.clone() if c else x], verbose=False)[0] / 1E9 * 2 if thop else 0  # FLOPs
+        c = m == self.model[-1] and isinstance(x, list)  # is final layer list, copy input as inplace fix
+        flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1E9 * 2 if thop else 0  # FLOPs
         t = time_sync()
         for _ in range(10):
-            m(x.clone() if c else x)
+            m(x.copy() if c else x)
         dt.append((time_sync() - t) * 100)
         if m == self.model[0]:
             LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s}  module")
-        LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f}  {m.type}')
+        LOGGER.info(f'{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f}  {m.type}')
         if c:
             LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s}  Total")
 
@@ -338,7 +337,7 @@ class ClassificationModel(BaseModel):
     """YOLOv8 classification model."""
 
     def __init__(self,
-                 cfg=None,
+                 cfg='yolov8n-cls.yaml',
                  model=None,
                  ch=3,
                  nc=None,