mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
Object Counter improvements (#8648)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
ddc94a6981
commit
3596a77a5a
@ -38,11 +38,12 @@ class ObjectCounter:
|
|||||||
|
|
||||||
self.names = None # Classes names
|
self.names = None # Classes names
|
||||||
self.annotator = None # Annotator
|
self.annotator = None # Annotator
|
||||||
|
self.window_name = "Ultralytics YOLOv8 Object Counter"
|
||||||
|
|
||||||
# Object counting Information
|
# Object counting Information
|
||||||
self.in_counts = 0
|
self.in_counts = 0
|
||||||
self.out_counts = 0
|
self.out_counts = 0
|
||||||
self.counting_list = []
|
self.counting_dict = {}
|
||||||
self.count_txt_thickness = 0
|
self.count_txt_thickness = 0
|
||||||
self.count_txt_color = (0, 0, 0)
|
self.count_txt_color = (0, 0, 0)
|
||||||
self.count_color = (255, 255, 255)
|
self.count_color = (255, 255, 255)
|
||||||
@ -106,12 +107,12 @@ class ObjectCounter:
|
|||||||
print("Line Counter Initiated.")
|
print("Line Counter Initiated.")
|
||||||
self.reg_pts = reg_pts
|
self.reg_pts = reg_pts
|
||||||
self.counting_region = LineString(self.reg_pts)
|
self.counting_region = LineString(self.reg_pts)
|
||||||
elif len(reg_pts) == 4:
|
elif len(reg_pts) >= 3:
|
||||||
print("Region Counter Initiated.")
|
print("Region Counter Initiated.")
|
||||||
self.reg_pts = reg_pts
|
self.reg_pts = reg_pts
|
||||||
self.counting_region = Polygon(self.reg_pts)
|
self.counting_region = Polygon(self.reg_pts)
|
||||||
else:
|
else:
|
||||||
print("Invalid Region points provided, region_points can be 2 or 4")
|
print("Invalid Region points provided, region_points must be 2 for lines or >= 3 for polygons.")
|
||||||
print("Using Line Counter Now")
|
print("Using Line Counter Now")
|
||||||
self.counting_region = LineString(self.reg_pts)
|
self.counting_region = LineString(self.reg_pts)
|
||||||
|
|
||||||
@ -158,55 +159,70 @@ class ObjectCounter:
|
|||||||
|
|
||||||
def extract_and_process_tracks(self, tracks):
|
def extract_and_process_tracks(self, tracks):
|
||||||
"""Extracts and processes tracks for object counting in a video stream."""
|
"""Extracts and processes tracks for object counting in a video stream."""
|
||||||
boxes = tracks[0].boxes.xyxy.cpu()
|
|
||||||
clss = tracks[0].boxes.cls.cpu().tolist()
|
|
||||||
track_ids = tracks[0].boxes.id.int().cpu().tolist()
|
|
||||||
|
|
||||||
# Annotator Init and region drawing
|
# Annotator Init and region drawing
|
||||||
self.annotator = Annotator(self.im0, self.tf, self.names)
|
self.annotator = Annotator(self.im0, self.tf, self.names)
|
||||||
self.annotator.draw_region(reg_pts=self.reg_pts, color=self.region_color, thickness=self.region_thickness)
|
|
||||||
|
|
||||||
# Extract tracks
|
if tracks[0].boxes.id is not None:
|
||||||
for box, track_id, cls in zip(boxes, track_ids, clss):
|
boxes = tracks[0].boxes.xyxy.cpu()
|
||||||
# Draw bounding box
|
clss = tracks[0].boxes.cls.cpu().tolist()
|
||||||
self.annotator.box_label(box, label=f"{track_id}:{self.names[cls]}", color=colors(int(cls), True))
|
track_ids = tracks[0].boxes.id.int().cpu().tolist()
|
||||||
|
|
||||||
# Draw Tracks
|
# Extract tracks
|
||||||
track_line = self.track_history[track_id]
|
for box, track_id, cls in zip(boxes, track_ids, clss):
|
||||||
track_line.append((float((box[0] + box[2]) / 2), float((box[1] + box[3]) / 2)))
|
# Draw bounding box
|
||||||
if len(track_line) > 30:
|
self.annotator.box_label(box, label=f"{track_id}:{self.names[cls]}", color=colors(int(cls), True))
|
||||||
track_line.pop(0)
|
|
||||||
|
|
||||||
# Draw track trails
|
# Draw Tracks
|
||||||
if self.draw_tracks:
|
track_line = self.track_history[track_id]
|
||||||
self.annotator.draw_centroid_and_tracks(
|
track_line.append((float((box[0] + box[2]) / 2), float((box[1] + box[3]) / 2)))
|
||||||
track_line, color=self.track_color, track_thickness=self.track_thickness
|
if len(track_line) > 30:
|
||||||
)
|
track_line.pop(0)
|
||||||
|
|
||||||
prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None
|
# Draw track trails
|
||||||
|
if self.draw_tracks:
|
||||||
|
self.annotator.draw_centroid_and_tracks(
|
||||||
|
track_line, color=self.track_color, track_thickness=self.track_thickness
|
||||||
|
)
|
||||||
|
|
||||||
# Count objects
|
prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None
|
||||||
if len(self.reg_pts) == 4:
|
centroid = Point((box[:2] + box[2:]) / 2)
|
||||||
if (
|
|
||||||
prev_position is not None
|
|
||||||
and self.counting_region.contains(Point(track_line[-1]))
|
|
||||||
and track_id not in self.counting_list
|
|
||||||
):
|
|
||||||
self.counting_list.append(track_id)
|
|
||||||
if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
|
|
||||||
self.in_counts += 1
|
|
||||||
else:
|
|
||||||
self.out_counts += 1
|
|
||||||
|
|
||||||
elif len(self.reg_pts) == 2:
|
# Count objects
|
||||||
if prev_position is not None:
|
if len(self.reg_pts) >= 3: # any polygon
|
||||||
distance = Point(track_line[-1]).distance(self.counting_region)
|
is_inside = self.counting_region.contains(centroid)
|
||||||
if distance < self.line_dist_thresh and track_id not in self.counting_list:
|
current_position = "in" if is_inside else "out"
|
||||||
self.counting_list.append(track_id)
|
|
||||||
if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
|
if prev_position is not None:
|
||||||
|
if self.counting_dict[track_id] != current_position and is_inside:
|
||||||
self.in_counts += 1
|
self.in_counts += 1
|
||||||
else:
|
self.counting_dict[track_id] = "in"
|
||||||
|
elif self.counting_dict[track_id] != current_position and not is_inside:
|
||||||
self.out_counts += 1
|
self.out_counts += 1
|
||||||
|
self.counting_dict[track_id] = "out"
|
||||||
|
else:
|
||||||
|
self.counting_dict[track_id] = current_position
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.counting_dict[track_id] = current_position
|
||||||
|
|
||||||
|
elif len(self.reg_pts) == 2:
|
||||||
|
if prev_position is not None:
|
||||||
|
is_inside = (box[0] - prev_position[0]) * (
|
||||||
|
self.counting_region.centroid.x - prev_position[0]
|
||||||
|
) > 0
|
||||||
|
current_position = "in" if is_inside else "out"
|
||||||
|
|
||||||
|
if self.counting_dict[track_id] != current_position and is_inside:
|
||||||
|
self.in_counts += 1
|
||||||
|
self.counting_dict[track_id] = "in"
|
||||||
|
elif self.counting_dict[track_id] != current_position and not is_inside:
|
||||||
|
self.out_counts += 1
|
||||||
|
self.counting_dict[track_id] = "out"
|
||||||
|
else:
|
||||||
|
self.counting_dict[track_id] = current_position
|
||||||
|
else:
|
||||||
|
self.counting_dict[track_id] = None
|
||||||
|
|
||||||
incount_label = f"In Count : {self.in_counts}"
|
incount_label = f"In Count : {self.in_counts}"
|
||||||
outcount_label = f"OutCount : {self.out_counts}"
|
outcount_label = f"OutCount : {self.out_counts}"
|
||||||
@ -233,12 +249,11 @@ class ObjectCounter:
|
|||||||
def display_frames(self):
|
def display_frames(self):
|
||||||
"""Display frame."""
|
"""Display frame."""
|
||||||
if self.env_check:
|
if self.env_check:
|
||||||
cv2.namedWindow("Ultralytics YOLOv8 Object Counter")
|
self.annotator.draw_region(reg_pts=self.reg_pts, color=self.region_color, thickness=self.region_thickness)
|
||||||
|
cv2.namedWindow(self.window_name)
|
||||||
if len(self.reg_pts) == 4: # only add mouse event If user drawn region
|
if len(self.reg_pts) == 4: # only add mouse event If user drawn region
|
||||||
cv2.setMouseCallback(
|
cv2.setMouseCallback(self.window_name, self.mouse_event_for_region, {"region_points": self.reg_pts})
|
||||||
"Ultralytics YOLOv8 Object Counter", self.mouse_event_for_region, {"region_points": self.reg_pts}
|
cv2.imshow(self.window_name, self.im0)
|
||||||
)
|
|
||||||
cv2.imshow("Ultralytics YOLOv8 Object Counter", self.im0)
|
|
||||||
# Break Window
|
# Break Window
|
||||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||||
return
|
return
|
||||||
@ -252,12 +267,7 @@ class ObjectCounter:
|
|||||||
tracks (list): List of tracks obtained from the object tracking process.
|
tracks (list): List of tracks obtained from the object tracking process.
|
||||||
"""
|
"""
|
||||||
self.im0 = im0 # store image
|
self.im0 = im0 # store image
|
||||||
|
self.extract_and_process_tracks(tracks) # draw region even if no objects
|
||||||
if tracks[0].boxes.id is None:
|
|
||||||
if self.view_img:
|
|
||||||
self.display_frames()
|
|
||||||
return im0
|
|
||||||
self.extract_and_process_tracks(tracks)
|
|
||||||
|
|
||||||
if self.view_img:
|
if self.view_img:
|
||||||
self.display_frames()
|
self.display_frames()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user