Object Counter improvements (#8648)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Henry 2024-03-05 19:41:16 +01:00 committed by GitHub
parent ddc94a6981
commit 3596a77a5a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -38,11 +38,12 @@ class ObjectCounter:
self.names = None # Classes names
self.annotator = None # Annotator
self.window_name = "Ultralytics YOLOv8 Object Counter"
# Object counting Information
self.in_counts = 0
self.out_counts = 0
self.counting_list = []
self.counting_dict = {}
self.count_txt_thickness = 0
self.count_txt_color = (0, 0, 0)
self.count_color = (255, 255, 255)
@ -106,12 +107,12 @@ class ObjectCounter:
print("Line Counter Initiated.")
self.reg_pts = reg_pts
self.counting_region = LineString(self.reg_pts)
elif len(reg_pts) == 4:
elif len(reg_pts) >= 3:
print("Region Counter Initiated.")
self.reg_pts = reg_pts
self.counting_region = Polygon(self.reg_pts)
else:
print("Invalid Region points provided, region_points can be 2 or 4")
print("Invalid Region points provided, region_points must be 2 for lines or >= 3 for polygons.")
print("Using Line Counter Now")
self.counting_region = LineString(self.reg_pts)
@ -158,55 +159,70 @@ class ObjectCounter:
def extract_and_process_tracks(self, tracks):
"""Extracts and processes tracks for object counting in a video stream."""
boxes = tracks[0].boxes.xyxy.cpu()
clss = tracks[0].boxes.cls.cpu().tolist()
track_ids = tracks[0].boxes.id.int().cpu().tolist()
# Annotator Init and region drawing
self.annotator = Annotator(self.im0, self.tf, self.names)
self.annotator.draw_region(reg_pts=self.reg_pts, color=self.region_color, thickness=self.region_thickness)
# Extract tracks
for box, track_id, cls in zip(boxes, track_ids, clss):
# Draw bounding box
self.annotator.box_label(box, label=f"{track_id}:{self.names[cls]}", color=colors(int(cls), True))
if tracks[0].boxes.id is not None:
boxes = tracks[0].boxes.xyxy.cpu()
clss = tracks[0].boxes.cls.cpu().tolist()
track_ids = tracks[0].boxes.id.int().cpu().tolist()
# Draw Tracks
track_line = self.track_history[track_id]
track_line.append((float((box[0] + box[2]) / 2), float((box[1] + box[3]) / 2)))
if len(track_line) > 30:
track_line.pop(0)
# Extract tracks
for box, track_id, cls in zip(boxes, track_ids, clss):
# Draw bounding box
self.annotator.box_label(box, label=f"{track_id}:{self.names[cls]}", color=colors(int(cls), True))
# Draw track trails
if self.draw_tracks:
self.annotator.draw_centroid_and_tracks(
track_line, color=self.track_color, track_thickness=self.track_thickness
)
# Draw Tracks
track_line = self.track_history[track_id]
track_line.append((float((box[0] + box[2]) / 2), float((box[1] + box[3]) / 2)))
if len(track_line) > 30:
track_line.pop(0)
prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None
# Draw track trails
if self.draw_tracks:
self.annotator.draw_centroid_and_tracks(
track_line, color=self.track_color, track_thickness=self.track_thickness
)
# Count objects
if len(self.reg_pts) == 4:
if (
prev_position is not None
and self.counting_region.contains(Point(track_line[-1]))
and track_id not in self.counting_list
):
self.counting_list.append(track_id)
if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
self.in_counts += 1
else:
self.out_counts += 1
prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None
centroid = Point((box[:2] + box[2:]) / 2)
elif len(self.reg_pts) == 2:
if prev_position is not None:
distance = Point(track_line[-1]).distance(self.counting_region)
if distance < self.line_dist_thresh and track_id not in self.counting_list:
self.counting_list.append(track_id)
if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
# Count objects
if len(self.reg_pts) >= 3: # any polygon
is_inside = self.counting_region.contains(centroid)
current_position = "in" if is_inside else "out"
if prev_position is not None:
if self.counting_dict[track_id] != current_position and is_inside:
self.in_counts += 1
else:
self.counting_dict[track_id] = "in"
elif self.counting_dict[track_id] != current_position and not is_inside:
self.out_counts += 1
self.counting_dict[track_id] = "out"
else:
self.counting_dict[track_id] = current_position
else:
self.counting_dict[track_id] = current_position
elif len(self.reg_pts) == 2:
if prev_position is not None:
is_inside = (box[0] - prev_position[0]) * (
self.counting_region.centroid.x - prev_position[0]
) > 0
current_position = "in" if is_inside else "out"
if self.counting_dict[track_id] != current_position and is_inside:
self.in_counts += 1
self.counting_dict[track_id] = "in"
elif self.counting_dict[track_id] != current_position and not is_inside:
self.out_counts += 1
self.counting_dict[track_id] = "out"
else:
self.counting_dict[track_id] = current_position
else:
self.counting_dict[track_id] = None
incount_label = f"In Count : {self.in_counts}"
outcount_label = f"OutCount : {self.out_counts}"
@ -233,12 +249,11 @@ class ObjectCounter:
def display_frames(self):
"""Display frame."""
if self.env_check:
cv2.namedWindow("Ultralytics YOLOv8 Object Counter")
self.annotator.draw_region(reg_pts=self.reg_pts, color=self.region_color, thickness=self.region_thickness)
cv2.namedWindow(self.window_name)
if len(self.reg_pts) == 4: # only add mouse event If user drawn region
cv2.setMouseCallback(
"Ultralytics YOLOv8 Object Counter", self.mouse_event_for_region, {"region_points": self.reg_pts}
)
cv2.imshow("Ultralytics YOLOv8 Object Counter", self.im0)
cv2.setMouseCallback(self.window_name, self.mouse_event_for_region, {"region_points": self.reg_pts})
cv2.imshow(self.window_name, self.im0)
# Break Window
if cv2.waitKey(1) & 0xFF == ord("q"):
return
@ -252,12 +267,7 @@ class ObjectCounter:
tracks (list): List of tracks obtained from the object tracking process.
"""
self.im0 = im0 # store image
if tracks[0].boxes.id is None:
if self.view_img:
self.display_frames()
return im0
self.extract_and_process_tracks(tracks)
self.extract_and_process_tracks(tracks) # draw region even if no objects
if self.view_img:
self.display_frames()