Object Counter improvements (#8648)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Henry 2024-03-05 19:41:16 +01:00 committed by GitHub
parent ddc94a6981
commit 3596a77a5a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -38,11 +38,12 @@ class ObjectCounter:
self.names = None # Classes names self.names = None # Classes names
self.annotator = None # Annotator self.annotator = None # Annotator
self.window_name = "Ultralytics YOLOv8 Object Counter"
# Object counting Information # Object counting Information
self.in_counts = 0 self.in_counts = 0
self.out_counts = 0 self.out_counts = 0
self.counting_list = [] self.counting_dict = {}
self.count_txt_thickness = 0 self.count_txt_thickness = 0
self.count_txt_color = (0, 0, 0) self.count_txt_color = (0, 0, 0)
self.count_color = (255, 255, 255) self.count_color = (255, 255, 255)
@ -106,12 +107,12 @@ class ObjectCounter:
print("Line Counter Initiated.") print("Line Counter Initiated.")
self.reg_pts = reg_pts self.reg_pts = reg_pts
self.counting_region = LineString(self.reg_pts) self.counting_region = LineString(self.reg_pts)
elif len(reg_pts) == 4: elif len(reg_pts) >= 3:
print("Region Counter Initiated.") print("Region Counter Initiated.")
self.reg_pts = reg_pts self.reg_pts = reg_pts
self.counting_region = Polygon(self.reg_pts) self.counting_region = Polygon(self.reg_pts)
else: else:
print("Invalid Region points provided, region_points can be 2 or 4") print("Invalid Region points provided, region_points must be 2 for lines or >= 3 for polygons.")
print("Using Line Counter Now") print("Using Line Counter Now")
self.counting_region = LineString(self.reg_pts) self.counting_region = LineString(self.reg_pts)
@ -158,13 +159,14 @@ class ObjectCounter:
def extract_and_process_tracks(self, tracks): def extract_and_process_tracks(self, tracks):
"""Extracts and processes tracks for object counting in a video stream.""" """Extracts and processes tracks for object counting in a video stream."""
boxes = tracks[0].boxes.xyxy.cpu()
clss = tracks[0].boxes.cls.cpu().tolist()
track_ids = tracks[0].boxes.id.int().cpu().tolist()
# Annotator Init and region drawing # Annotator Init and region drawing
self.annotator = Annotator(self.im0, self.tf, self.names) self.annotator = Annotator(self.im0, self.tf, self.names)
self.annotator.draw_region(reg_pts=self.reg_pts, color=self.region_color, thickness=self.region_thickness)
if tracks[0].boxes.id is not None:
boxes = tracks[0].boxes.xyxy.cpu()
clss = tracks[0].boxes.cls.cpu().tolist()
track_ids = tracks[0].boxes.id.int().cpu().tolist()
# Extract tracks # Extract tracks
for box, track_id, cls in zip(boxes, track_ids, clss): for box, track_id, cls in zip(boxes, track_ids, clss):
@ -184,29 +186,43 @@ class ObjectCounter:
) )
prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None
centroid = Point((box[:2] + box[2:]) / 2)
# Count objects # Count objects
if len(self.reg_pts) == 4: if len(self.reg_pts) >= 3: # any polygon
if ( is_inside = self.counting_region.contains(centroid)
prev_position is not None current_position = "in" if is_inside else "out"
and self.counting_region.contains(Point(track_line[-1]))
and track_id not in self.counting_list if prev_position is not None:
): if self.counting_dict[track_id] != current_position and is_inside:
self.counting_list.append(track_id)
if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
self.in_counts += 1 self.in_counts += 1
else: self.counting_dict[track_id] = "in"
elif self.counting_dict[track_id] != current_position and not is_inside:
self.out_counts += 1 self.out_counts += 1
self.counting_dict[track_id] = "out"
else:
self.counting_dict[track_id] = current_position
else:
self.counting_dict[track_id] = current_position
elif len(self.reg_pts) == 2: elif len(self.reg_pts) == 2:
if prev_position is not None: if prev_position is not None:
distance = Point(track_line[-1]).distance(self.counting_region) is_inside = (box[0] - prev_position[0]) * (
if distance < self.line_dist_thresh and track_id not in self.counting_list: self.counting_region.centroid.x - prev_position[0]
self.counting_list.append(track_id) ) > 0
if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0: current_position = "in" if is_inside else "out"
if self.counting_dict[track_id] != current_position and is_inside:
self.in_counts += 1 self.in_counts += 1
else: self.counting_dict[track_id] = "in"
elif self.counting_dict[track_id] != current_position and not is_inside:
self.out_counts += 1 self.out_counts += 1
self.counting_dict[track_id] = "out"
else:
self.counting_dict[track_id] = current_position
else:
self.counting_dict[track_id] = None
incount_label = f"In Count : {self.in_counts}" incount_label = f"In Count : {self.in_counts}"
outcount_label = f"OutCount : {self.out_counts}" outcount_label = f"OutCount : {self.out_counts}"
@ -233,12 +249,11 @@ class ObjectCounter:
def display_frames(self): def display_frames(self):
"""Display frame.""" """Display frame."""
if self.env_check: if self.env_check:
cv2.namedWindow("Ultralytics YOLOv8 Object Counter") self.annotator.draw_region(reg_pts=self.reg_pts, color=self.region_color, thickness=self.region_thickness)
cv2.namedWindow(self.window_name)
if len(self.reg_pts) == 4: # only add mouse event If user drawn region if len(self.reg_pts) == 4: # only add mouse event If user drawn region
cv2.setMouseCallback( cv2.setMouseCallback(self.window_name, self.mouse_event_for_region, {"region_points": self.reg_pts})
"Ultralytics YOLOv8 Object Counter", self.mouse_event_for_region, {"region_points": self.reg_pts} cv2.imshow(self.window_name, self.im0)
)
cv2.imshow("Ultralytics YOLOv8 Object Counter", self.im0)
# Break Window # Break Window
if cv2.waitKey(1) & 0xFF == ord("q"): if cv2.waitKey(1) & 0xFF == ord("q"):
return return
@ -252,12 +267,7 @@ class ObjectCounter:
tracks (list): List of tracks obtained from the object tracking process. tracks (list): List of tracks obtained from the object tracking process.
""" """
self.im0 = im0 # store image self.im0 = im0 # store image
self.extract_and_process_tracks(tracks) # draw region even if no objects
if tracks[0].boxes.id is None:
if self.view_img:
self.display_frames()
return im0
self.extract_and_process_tracks(tracks)
if self.view_img: if self.view_img:
self.display_frames() self.display_frames()