From 4916014af266bea2955ab7a8555dea63dc06782b Mon Sep 17 00:00:00 2001
From: Glenn Jocher <glenn.jocher@ultralytics.com>
Date: Thu, 13 Apr 2023 20:19:42 +0200
Subject: [PATCH] `ultralytics 8.0.76` minor fixes and improvements (#2004)

Co-authored-by: Seungtaek Kim <seungtaek.kim.94@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
Co-authored-by: Ercalvez <45692523+Ercalvez@users.noreply.github.com>
Co-authored-by: Erwan CALVEZ <ecalvez@enib.fr>
---
 .github/workflows/ci.yaml           |  6 +++---
 docker/Dockerfile                   |  2 +-
 docker/Dockerfile-cpu               |  2 +-
 docs/usage/python.md                |  8 ++++----
 setup.py                            |  5 +++--
 tests/test_python.py                |  2 +-
 ultralytics/__init__.py             |  2 +-
 ultralytics/hub/session.py          |  4 +++-
 ultralytics/yolo/engine/model.py    |  2 +-
 ultralytics/yolo/engine/results.py  |  2 +-
 ultralytics/yolo/engine/trainer.py  |  4 ++--
 ultralytics/yolo/utils/errors.py    |  9 +++++++++
 ultralytics/yolo/utils/metrics.py   | 24 +++++++++++++++++++++---
 ultralytics/yolo/utils/ops.py       |  6 +++---
 ultralytics/yolo/v8/classify/val.py | 11 +++++++++--
 15 files changed, 63 insertions(+), 26 deletions(-)
 create mode 100644 ultralytics/yolo/utils/errors.py

diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 7a274cf1..723886d6 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -76,11 +76,11 @@ jobs:
         run: |
           python -m pip install --upgrade pip wheel
           if [ "${{ matrix.os }}" == "macos-latest" ]; then
-              pip install -e . coremltools openvino-dev tensorflow-macos tensorflowjs --extra-index-url https://download.pytorch.org/whl/cpu
+              pip install -e '.[export-macos]' --extra-index-url https://download.pytorch.org/whl/cpu
             else
-              pip install -e . coremltools openvino-dev tensorflow-cpu tensorflowjs --extra-index-url https://download.pytorch.org/whl/cpu
+              pip install -e '.[export-cpu]' --extra-index-url https://download.pytorch.org/whl/cpu
           fi
-          yolo export format=tflite
+          yolo export format=tflite imgsz=32
       - name: Check environment
         run: |
           echo "RUNNER_OS is ${{ runner.os }}"
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 22cd80bf..c79518a8 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -30,7 +30,7 @@ ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt /u
 
 # Install pip packages
 RUN python3 -m pip install --upgrade pip wheel
-RUN pip install --no-cache '.[export]' albumentations comet gsutil notebook
+RUN pip install --no-cache . albumentations comet gsutil notebook
 
 # Set environment variables
 ENV OMP_NUM_THREADS=1
diff --git a/docker/Dockerfile-cpu b/docker/Dockerfile-cpu
index bb4bd5b4..a9c3d790 100644
--- a/docker/Dockerfile-cpu
+++ b/docker/Dockerfile-cpu
@@ -26,7 +26,7 @@ ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt /u
 
 # Install pip packages
 RUN python3 -m pip install --upgrade pip wheel
-RUN pip install --no-cache '.[export]' albumentations gsutil notebook \
+RUN pip install --no-cache . albumentations gsutil notebook \
         --extra-index-url https://download.pytorch.org/whl/cpu
 
 # Cleanup
diff --git a/docs/usage/python.md b/docs/usage/python.md
index a4f69448..867835e6 100644
--- a/docs/usage/python.md
+++ b/docs/usage/python.md
@@ -139,7 +139,7 @@ predicts the classes and locations of objects in the input images or videos.
         results = model.predict(source=0, stream=True)
 
         for result in results:
-            # detection
+            # Detection
             result.boxes.xyxy   # box with xyxy format, (N, 4)
             result.boxes.xywh   # box with xywh format, (N, 4)
             result.boxes.xyxyn  # box with xyxy format but normalized, (N, 4)
@@ -147,12 +147,12 @@ predicts the classes and locations of objects in the input images or videos.
             result.boxes.conf   # confidence score, (N, 1)
             result.boxes.cls    # cls, (N, 1)
 
-            # segmentation
-            result.masks.masks     # masks, (N, H, W)
+            # Segmentation
+            result.masks.data      # masks, (N, H, W)
             result.masks.xy        # x,y segments (pixels), List[segment] * N
             result.masks.xyn       # x,y segments (normalized), List[segment] * N
 
-            # classification
+            # Classification
             result.probs     # cls prob, (num_class, )
 
         # Each result is composed of torch.Tensor by default, 
diff --git a/setup.py b/setup.py
index ba0296ef..a0c44c33 100644
--- a/setup.py
+++ b/setup.py
@@ -39,8 +39,9 @@ setup(
     install_requires=REQUIREMENTS + PKG_REQUIREMENTS,
     extras_require={
         'dev': ['check-manifest', 'pytest', 'pytest-cov', 'coverage', 'mkdocs-material', 'mkdocstrings[python]'],
-        'export': ['coremltools>=6.0', 'onnx', 'onnxsim', 'onnxruntime', 'openvino-dev>=2022.3'],
-        'tf': ['onnx2tf', 'sng4onnx', 'tflite_support', 'tensorflow']},
+        'export': ['coremltools>=6.0', 'openvino-dev>=2022.3', 'tensorflow', 'tensorflowjs'],
+        'export-cpu': ['coremltools>=6.0', 'openvino-dev>=2022.3', 'tensorflow-cpu', 'tensorflowjs'],
+        'export-macos': ['coremltools>=6.0', 'openvino-dev>=2022.3', 'tensorflow-macos', 'tensorflowjs']},
     classifiers=[
         'Development Status :: 4 - Beta',
         'Intended Audience :: Developers',
diff --git a/tests/test_python.py b/tests/test_python.py
index 2a50dc23..d71b47e6 100644
--- a/tests/test_python.py
+++ b/tests/test_python.py
@@ -217,7 +217,7 @@ def test_result():
     res[0].plot(conf=True, boxes=False, masks=True)
     res[0].plot(pil=True)
     res[0] = res[0].cpu().numpy()
-    print(res[0].path, res[0].masks.masks)
+    print(res[0].path, res[0].masks.data)
 
     model = YOLO('yolov8n.pt')
     res = model(SOURCE)
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index 37cef40a..cf94ef85 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license
 
-__version__ = '8.0.75'
+__version__ = '8.0.76'
 
 from ultralytics.hub import start
 from ultralytics.yolo.engine.model import YOLO
diff --git a/ultralytics/hub/session.py b/ultralytics/hub/session.py
index a2aa6f53..542d57d9 100644
--- a/ultralytics/hub/session.py
+++ b/ultralytics/hub/session.py
@@ -8,6 +8,7 @@ import requests
 
 from ultralytics.hub.utils import HUB_API_ROOT, PREFIX, check_dataset_disk_space, smart_request
 from ultralytics.yolo.utils import LOGGER, __version__, checks, emojis, is_colab, threaded
+from ultralytics.yolo.utils.errors import HUBModelError
 
 AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
 
@@ -55,7 +56,8 @@ class HUBTrainingSession:
         elif len(url) == 20:
             key, model_id = '', url
         else:
-            raise ValueError(f'Invalid HUBTrainingSession input: {url}')
+            raise HUBModelError(f"model='{url}' not found. Check format is correct, i.e. "
+                                f"model='https://hub.ultralytics.com/models/MODEL_ID' and try again.")
 
         # Authorize
         auth = Auth(key)
diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py
index 26e19fd3..6361722a 100644
--- a/ultralytics/yolo/engine/model.py
+++ b/ultralytics/yolo/engine/model.py
@@ -116,7 +116,7 @@ class YOLO:
     @staticmethod
     def is_hub_model(model):
         return any((
-            model.startswith('https://hub.ultralytics.com/models/'),
+            model.startswith('https://hub.ultra'),  # i.e. https://hub.ultralytics.com/models/MODEL_ID
             [len(x) for x in model.split('_')] == [42, 20],  # APIKEY_MODELID
             len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\')))  # MODELID
 
diff --git a/ultralytics/yolo/engine/results.py b/ultralytics/yolo/engine/results.py
index b47448bd..9b0fb0fe 100644
--- a/ultralytics/yolo/engine/results.py
+++ b/ultralytics/yolo/engine/results.py
@@ -207,7 +207,7 @@ class Results(SimpleClass):
         if pred_masks and show_masks:
             if img_gpu is None:
                 img = LetterBox(pred_masks.shape[1:])(image=annotator.result())
-                img_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.masks.device).permute(
+                img_gpu = torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device).permute(
                     2, 0, 1).flip(0).contiguous() / 255
             annotator.masks(pred_masks.data, colors=[colors(x, True) for x in pred_boxes.cls], im_gpu=img_gpu)
 
diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py
index 0100d9ce..b920465e 100644
--- a/ultralytics/yolo/engine/trainer.py
+++ b/ultralytics/yolo/engine/trainer.py
@@ -632,9 +632,9 @@ def check_amp(model):
 
     def amp_allclose(m, im):
         # All close FP32 vs AMP results
-        a = m(im, device=device, verbose=False)[0].boxes.boxes  # FP32 inference
+        a = m(im, device=device, verbose=False)[0].boxes.data  # FP32 inference
         with torch.cuda.amp.autocast(True):
-            b = m(im, device=device, verbose=False)[0].boxes.boxes  # AMP inference
+            b = m(im, device=device, verbose=False)[0].boxes.data  # AMP inference
         del m
         return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5)  # close to 0.5 absolute tolerance
 
diff --git a/ultralytics/yolo/utils/errors.py b/ultralytics/yolo/utils/errors.py
new file mode 100644
index 00000000..71027359
--- /dev/null
+++ b/ultralytics/yolo/utils/errors.py
@@ -0,0 +1,9 @@
+# Ultralytics YOLO 🚀, GPL-3.0 license
+
+from ultralytics.yolo.utils import emojis
+
+
+class HUBModelError(Exception):
+
+    def __init__(self, message='Model not found. Please check model URL and try again.'):
+        super().__init__(emojis(message))
diff --git a/ultralytics/yolo/utils/metrics.py b/ultralytics/yolo/utils/metrics.py
index d00d1c73..fcba4aa5 100644
--- a/ultralytics/yolo/utils/metrics.py
+++ b/ultralytics/yolo/utils/metrics.py
@@ -172,19 +172,37 @@ class FocalLoss(nn.Module):
 
 class ConfusionMatrix:
     # Updated version of https://github.com/kaanakan/object_detection_confusion_matrix
-    def __init__(self, nc, conf=0.25, iou_thres=0.45):
-        self.matrix = np.zeros((nc + 1, nc + 1))
+    def __init__(self, nc, conf=0.25, iou_thres=0.45, task='detect'):
+        self.task = task
+        self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == 'detect' else np.zeros((nc, nc))
         self.nc = nc  # number of classes
         self.conf = conf
         self.iou_thres = iou_thres
 
+    def process_cls_preds(self, preds, targets):
+        """
+        Update confusion matrix for classification task
+
+        Arguments:
+            preds (Array[N, min(nc,5)])
+            targets (Array[N, 1])
+
+        Returns:
+            None, updates confusion matrix accordingly
+        """
+        preds, targets = torch.cat(preds)[:, 0], torch.cat(targets)
+        for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):
+            self.matrix[t][p] += 1
+
     def process_batch(self, detections, labels):
         """
         Return intersection-over-union (Jaccard index) of boxes.
         Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
+
         Arguments:
             detections (Array[N, 6]), x1, y1, x2, y2, conf, class
             labels (Array[M, 5]), class, x1, y1, x2, y2
+
         Returns:
             None, updates confusion matrix accordingly
         """
@@ -231,7 +249,7 @@ class ConfusionMatrix:
         tp = self.matrix.diagonal()  # true positives
         fp = self.matrix.sum(1) - tp  # false positives
         # fn = self.matrix.sum(0) - tp  # false negatives (missed detections)
-        return tp[:-1], fp[:-1]  # remove background class
+        return (tp[:-1], fp[:-1]) if self.task == 'detect' else (tp, fp)  # remove background class if task=detect
 
     @TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')
     @plt_settings()
diff --git a/ultralytics/yolo/utils/ops.py b/ultralytics/yolo/utils/ops.py
index a59197f8..04fb8837 100644
--- a/ultralytics/yolo/utils/ops.py
+++ b/ultralytics/yolo/utils/ops.py
@@ -547,9 +547,9 @@ def crop_mask(masks, boxes):
       (torch.Tensor): The masks are being cropped to the bounding box.
     """
     n, h, w = masks.shape
-    x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # x1 shape(1,1,n)
-    r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]  # rows shape(1,w,1)
-    c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]  # cols shape(h,1,1)
+    x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # x1 shape(n,1,1)
+    r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]  # rows shape(1,1,w)
+    c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]  # cols shape(1,h,1)
 
     return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
 
diff --git a/ultralytics/yolo/v8/classify/val.py b/ultralytics/yolo/v8/classify/val.py
index aede667c..e1b359d6 100644
--- a/ultralytics/yolo/v8/classify/val.py
+++ b/ultralytics/yolo/v8/classify/val.py
@@ -3,7 +3,7 @@
 from ultralytics.yolo.data import build_classification_dataloader
 from ultralytics.yolo.engine.validator import BaseValidator
 from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER
-from ultralytics.yolo.utils.metrics import ClassifyMetrics
+from ultralytics.yolo.utils.metrics import ClassifyMetrics, ConfusionMatrix
 
 
 class ClassificationValidator(BaseValidator):
@@ -12,11 +12,15 @@ class ClassificationValidator(BaseValidator):
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
         self.args.task = 'classify'
         self.metrics = ClassifyMetrics()
+        self.save_dir = save_dir
 
     def get_desc(self):
         return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc')
 
     def init_metrics(self, model):
+        self.names = model.names
+        self.nc = len(model.names)
+        self.confusion_matrix = ConfusionMatrix(nc=self.nc, task='classify')
         self.pred = []
         self.targets = []
 
@@ -32,8 +36,9 @@ class ClassificationValidator(BaseValidator):
         self.targets.append(batch['cls'])
 
     def finalize_metrics(self, *args, **kwargs):
+        self.confusion_matrix.process_cls_preds(self.pred, self.targets)
         self.metrics.speed = self.speed
-        # self.metrics.confusion_matrix = self.confusion_matrix  # TODO: classification ConfusionMatrix
+        self.metrics.confusion_matrix = self.confusion_matrix
 
     def get_stats(self):
         self.metrics.process(self.targets, self.pred)
@@ -50,6 +55,8 @@ class ClassificationValidator(BaseValidator):
     def print_results(self):
         pf = '%22s' + '%11.3g' * len(self.metrics.keys)  # print format
         LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5))
+        if self.args.plots:
+            self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values()))
 
 
 def val(cfg=DEFAULT_CFG, use_python=False):