mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
Object Counter improvements (#8648)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
ddc94a6981
commit
3596a77a5a
@ -38,11 +38,12 @@ class ObjectCounter:
|
||||
|
||||
self.names = None # Classes names
|
||||
self.annotator = None # Annotator
|
||||
self.window_name = "Ultralytics YOLOv8 Object Counter"
|
||||
|
||||
# Object counting Information
|
||||
self.in_counts = 0
|
||||
self.out_counts = 0
|
||||
self.counting_list = []
|
||||
self.counting_dict = {}
|
||||
self.count_txt_thickness = 0
|
||||
self.count_txt_color = (0, 0, 0)
|
||||
self.count_color = (255, 255, 255)
|
||||
@ -106,12 +107,12 @@ class ObjectCounter:
|
||||
print("Line Counter Initiated.")
|
||||
self.reg_pts = reg_pts
|
||||
self.counting_region = LineString(self.reg_pts)
|
||||
elif len(reg_pts) == 4:
|
||||
elif len(reg_pts) >= 3:
|
||||
print("Region Counter Initiated.")
|
||||
self.reg_pts = reg_pts
|
||||
self.counting_region = Polygon(self.reg_pts)
|
||||
else:
|
||||
print("Invalid Region points provided, region_points can be 2 or 4")
|
||||
print("Invalid Region points provided, region_points must be 2 for lines or >= 3 for polygons.")
|
||||
print("Using Line Counter Now")
|
||||
self.counting_region = LineString(self.reg_pts)
|
||||
|
||||
@ -158,13 +159,14 @@ class ObjectCounter:
|
||||
|
||||
def extract_and_process_tracks(self, tracks):
|
||||
"""Extracts and processes tracks for object counting in a video stream."""
|
||||
boxes = tracks[0].boxes.xyxy.cpu()
|
||||
clss = tracks[0].boxes.cls.cpu().tolist()
|
||||
track_ids = tracks[0].boxes.id.int().cpu().tolist()
|
||||
|
||||
# Annotator Init and region drawing
|
||||
self.annotator = Annotator(self.im0, self.tf, self.names)
|
||||
self.annotator.draw_region(reg_pts=self.reg_pts, color=self.region_color, thickness=self.region_thickness)
|
||||
|
||||
if tracks[0].boxes.id is not None:
|
||||
boxes = tracks[0].boxes.xyxy.cpu()
|
||||
clss = tracks[0].boxes.cls.cpu().tolist()
|
||||
track_ids = tracks[0].boxes.id.int().cpu().tolist()
|
||||
|
||||
# Extract tracks
|
||||
for box, track_id, cls in zip(boxes, track_ids, clss):
|
||||
@ -184,29 +186,43 @@ class ObjectCounter:
|
||||
)
|
||||
|
||||
prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None
|
||||
centroid = Point((box[:2] + box[2:]) / 2)
|
||||
|
||||
# Count objects
|
||||
if len(self.reg_pts) == 4:
|
||||
if (
|
||||
prev_position is not None
|
||||
and self.counting_region.contains(Point(track_line[-1]))
|
||||
and track_id not in self.counting_list
|
||||
):
|
||||
self.counting_list.append(track_id)
|
||||
if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
|
||||
if len(self.reg_pts) >= 3: # any polygon
|
||||
is_inside = self.counting_region.contains(centroid)
|
||||
current_position = "in" if is_inside else "out"
|
||||
|
||||
if prev_position is not None:
|
||||
if self.counting_dict[track_id] != current_position and is_inside:
|
||||
self.in_counts += 1
|
||||
else:
|
||||
self.counting_dict[track_id] = "in"
|
||||
elif self.counting_dict[track_id] != current_position and not is_inside:
|
||||
self.out_counts += 1
|
||||
self.counting_dict[track_id] = "out"
|
||||
else:
|
||||
self.counting_dict[track_id] = current_position
|
||||
|
||||
else:
|
||||
self.counting_dict[track_id] = current_position
|
||||
|
||||
elif len(self.reg_pts) == 2:
|
||||
if prev_position is not None:
|
||||
distance = Point(track_line[-1]).distance(self.counting_region)
|
||||
if distance < self.line_dist_thresh and track_id not in self.counting_list:
|
||||
self.counting_list.append(track_id)
|
||||
if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
|
||||
is_inside = (box[0] - prev_position[0]) * (
|
||||
self.counting_region.centroid.x - prev_position[0]
|
||||
) > 0
|
||||
current_position = "in" if is_inside else "out"
|
||||
|
||||
if self.counting_dict[track_id] != current_position and is_inside:
|
||||
self.in_counts += 1
|
||||
else:
|
||||
self.counting_dict[track_id] = "in"
|
||||
elif self.counting_dict[track_id] != current_position and not is_inside:
|
||||
self.out_counts += 1
|
||||
self.counting_dict[track_id] = "out"
|
||||
else:
|
||||
self.counting_dict[track_id] = current_position
|
||||
else:
|
||||
self.counting_dict[track_id] = None
|
||||
|
||||
incount_label = f"In Count : {self.in_counts}"
|
||||
outcount_label = f"OutCount : {self.out_counts}"
|
||||
@ -233,12 +249,11 @@ class ObjectCounter:
|
||||
def display_frames(self):
|
||||
"""Display frame."""
|
||||
if self.env_check:
|
||||
cv2.namedWindow("Ultralytics YOLOv8 Object Counter")
|
||||
self.annotator.draw_region(reg_pts=self.reg_pts, color=self.region_color, thickness=self.region_thickness)
|
||||
cv2.namedWindow(self.window_name)
|
||||
if len(self.reg_pts) == 4: # only add mouse event If user drawn region
|
||||
cv2.setMouseCallback(
|
||||
"Ultralytics YOLOv8 Object Counter", self.mouse_event_for_region, {"region_points": self.reg_pts}
|
||||
)
|
||||
cv2.imshow("Ultralytics YOLOv8 Object Counter", self.im0)
|
||||
cv2.setMouseCallback(self.window_name, self.mouse_event_for_region, {"region_points": self.reg_pts})
|
||||
cv2.imshow(self.window_name, self.im0)
|
||||
# Break Window
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
return
|
||||
@ -252,12 +267,7 @@ class ObjectCounter:
|
||||
tracks (list): List of tracks obtained from the object tracking process.
|
||||
"""
|
||||
self.im0 = im0 # store image
|
||||
|
||||
if tracks[0].boxes.id is None:
|
||||
if self.view_img:
|
||||
self.display_frames()
|
||||
return im0
|
||||
self.extract_and_process_tracks(tracks)
|
||||
self.extract_and_process_tracks(tracks) # draw region even if no objects
|
||||
|
||||
if self.view_img:
|
||||
self.display_frames()
|
||||
|
Loading…
x
Reference in New Issue
Block a user