From 5be2ffbd13e278f108e8f37ba439d9c51e7c5822 Mon Sep 17 00:00:00 2001
From: Glenn Jocher <glenn.jocher@ultralytics.com>
Date: Mon, 25 Mar 2024 02:00:38 +0100
Subject: [PATCH] `ultralytics 8.1.34` Inference API robust imgsz checks
 (#9274)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
---
 ultralytics/__init__.py       |  2 +-
 ultralytics/engine/results.py | 32 +++++++++++++++++++++-----------
 ultralytics/utils/checks.py   |  2 ++
 3 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index 9a5b1ee3..d25836a0 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = "8.1.33"
+__version__ = "8.1.34"
 
 from ultralytics.data.explorer.explorer import Explorer
 from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
diff --git a/ultralytics/engine/results.py b/ultralytics/engine/results.py
index babe5d3b..85849c34 100644
--- a/ultralytics/engine/results.py
+++ b/ultralytics/engine/results.py
@@ -385,10 +385,10 @@ class Results(SimpleClass):
                 BGR=True,
             )
 
-    def summary(self, normalize=False):
+    def summary(self, normalize=False, decimals=5):
         """Convert the results to a summarized format."""
         if self.probs is not None:
-            LOGGER.warning("Warning: Classify task do not support `summary` and `tojson` yet.")
+            LOGGER.warning("Warning: Classify results do not support the `summary()` method yet.")
             return
 
         # Create list of detection dictionaries
@@ -396,28 +396,38 @@ class Results(SimpleClass):
         data = self.boxes.data.cpu().tolist()
         h, w = self.orig_shape if normalize else (1, 1)
         for i, row in enumerate(data):  # xyxy, track_id if tracking, conf, class_id
-            box = {"x1": row[0] / w, "y1": row[1] / h, "x2": row[2] / w, "y2": row[3] / h}
-            conf = row[-2]
+            box = {
+                "x1": round(row[0] / w, decimals),
+                "y1": round(row[1] / h, decimals),
+                "x2": round(row[2] / w, decimals),
+                "y2": round(row[3] / h, decimals),
+            }
+            conf = round(row[-2], decimals)
             class_id = int(row[-1])
-            name = self.names[class_id]
-            result = {"name": name, "class": class_id, "confidence": conf, "box": box}
+            result = {"name": self.names[class_id], "class": class_id, "confidence": conf, "box": box}
             if self.boxes.is_track:
                 result["track_id"] = int(row[-3])  # track ID
             if self.masks:
-                x, y = self.masks.xy[i][:, 0], self.masks.xy[i][:, 1]  # numpy array
-                result["segments"] = {"x": (x / w).tolist(), "y": (y / h).tolist()}
+                result["segments"] = {
+                    "x": (self.masks.xy[i][:, 0] / w).round(decimals).tolist(),
+                    "y": (self.masks.xy[i][:, 1] / h).round(decimals).tolist(),
+                }
             if self.keypoints is not None:
                 x, y, visible = self.keypoints[i].data[0].cpu().unbind(dim=1)  # torch Tensor
-                result["keypoints"] = {"x": (x / w).tolist(), "y": (y / h).tolist(), "visible": visible.tolist()}
+                result["keypoints"] = {
+                    "x": (x / w).numpy().round(decimals).tolist(),  # decimals named argument required
+                    "y": (y / h).numpy().round(decimals).tolist(),
+                    "visible": visible.numpy().round(decimals).tolist(),
+                }
             results.append(result)
 
         return results
 
-    def tojson(self, normalize=False):
+    def tojson(self, normalize=False, decimals=5):
         """Convert the results to JSON format."""
         import json
 
-        return json.dumps(self.summary(normalize=normalize), indent=2)
+        return json.dumps(self.summary(normalize=normalize, decimals=decimals), indent=2)
 
 
 class Boxes(BaseTensor):
diff --git a/ultralytics/utils/checks.py b/ultralytics/utils/checks.py
index 5d908851..c44ac0b6 100644
--- a/ultralytics/utils/checks.py
+++ b/ultralytics/utils/checks.py
@@ -142,6 +142,8 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
         imgsz = [imgsz]
     elif isinstance(imgsz, (list, tuple)):
         imgsz = list(imgsz)
+    elif isinstance(imgsz, str):  # i.e. '640' or '[640,640]'
+        imgsz = [int(imgsz)] if imgsz.isnumeric() else eval(imgsz)
     else:
         raise TypeError(
             f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "