mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-11-04 08:56:11 +08:00 
			
		
		
		
	Object Counter improvements (#8648)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
		
							parent
							
								
									ddc94a6981
								
							
						
					
					
						commit
						3596a77a5a
					
				@ -38,11 +38,12 @@ class ObjectCounter:
 | 
			
		||||
 | 
			
		||||
        self.names = None  # Classes names
 | 
			
		||||
        self.annotator = None  # Annotator
 | 
			
		||||
        self.window_name = "Ultralytics YOLOv8 Object Counter"
 | 
			
		||||
 | 
			
		||||
        # Object counting Information
 | 
			
		||||
        self.in_counts = 0
 | 
			
		||||
        self.out_counts = 0
 | 
			
		||||
        self.counting_list = []
 | 
			
		||||
        self.counting_dict = {}
 | 
			
		||||
        self.count_txt_thickness = 0
 | 
			
		||||
        self.count_txt_color = (0, 0, 0)
 | 
			
		||||
        self.count_color = (255, 255, 255)
 | 
			
		||||
@ -106,12 +107,12 @@ class ObjectCounter:
 | 
			
		||||
            print("Line Counter Initiated.")
 | 
			
		||||
            self.reg_pts = reg_pts
 | 
			
		||||
            self.counting_region = LineString(self.reg_pts)
 | 
			
		||||
        elif len(reg_pts) == 4:
 | 
			
		||||
        elif len(reg_pts) >= 3:
 | 
			
		||||
            print("Region Counter Initiated.")
 | 
			
		||||
            self.reg_pts = reg_pts
 | 
			
		||||
            self.counting_region = Polygon(self.reg_pts)
 | 
			
		||||
        else:
 | 
			
		||||
            print("Invalid Region points provided, region_points can be 2 or 4")
 | 
			
		||||
            print("Invalid Region points provided, region_points must be 2 for lines or >= 3 for polygons.")
 | 
			
		||||
            print("Using Line Counter Now")
 | 
			
		||||
            self.counting_region = LineString(self.reg_pts)
 | 
			
		||||
 | 
			
		||||
@ -158,55 +159,70 @@ class ObjectCounter:
 | 
			
		||||
 | 
			
		||||
    def extract_and_process_tracks(self, tracks):
 | 
			
		||||
        """Extracts and processes tracks for object counting in a video stream."""
 | 
			
		||||
        boxes = tracks[0].boxes.xyxy.cpu()
 | 
			
		||||
        clss = tracks[0].boxes.cls.cpu().tolist()
 | 
			
		||||
        track_ids = tracks[0].boxes.id.int().cpu().tolist()
 | 
			
		||||
 | 
			
		||||
        # Annotator Init and region drawing
 | 
			
		||||
        self.annotator = Annotator(self.im0, self.tf, self.names)
 | 
			
		||||
        self.annotator.draw_region(reg_pts=self.reg_pts, color=self.region_color, thickness=self.region_thickness)
 | 
			
		||||
 | 
			
		||||
        # Extract tracks
 | 
			
		||||
        for box, track_id, cls in zip(boxes, track_ids, clss):
 | 
			
		||||
            # Draw bounding box
 | 
			
		||||
            self.annotator.box_label(box, label=f"{track_id}:{self.names[cls]}", color=colors(int(cls), True))
 | 
			
		||||
        if tracks[0].boxes.id is not None:
 | 
			
		||||
            boxes = tracks[0].boxes.xyxy.cpu()
 | 
			
		||||
            clss = tracks[0].boxes.cls.cpu().tolist()
 | 
			
		||||
            track_ids = tracks[0].boxes.id.int().cpu().tolist()
 | 
			
		||||
 | 
			
		||||
            # Draw Tracks
 | 
			
		||||
            track_line = self.track_history[track_id]
 | 
			
		||||
            track_line.append((float((box[0] + box[2]) / 2), float((box[1] + box[3]) / 2)))
 | 
			
		||||
            if len(track_line) > 30:
 | 
			
		||||
                track_line.pop(0)
 | 
			
		||||
            # Extract tracks
 | 
			
		||||
            for box, track_id, cls in zip(boxes, track_ids, clss):
 | 
			
		||||
                # Draw bounding box
 | 
			
		||||
                self.annotator.box_label(box, label=f"{track_id}:{self.names[cls]}", color=colors(int(cls), True))
 | 
			
		||||
 | 
			
		||||
            # Draw track trails
 | 
			
		||||
            if self.draw_tracks:
 | 
			
		||||
                self.annotator.draw_centroid_and_tracks(
 | 
			
		||||
                    track_line, color=self.track_color, track_thickness=self.track_thickness
 | 
			
		||||
                )
 | 
			
		||||
                # Draw Tracks
 | 
			
		||||
                track_line = self.track_history[track_id]
 | 
			
		||||
                track_line.append((float((box[0] + box[2]) / 2), float((box[1] + box[3]) / 2)))
 | 
			
		||||
                if len(track_line) > 30:
 | 
			
		||||
                    track_line.pop(0)
 | 
			
		||||
 | 
			
		||||
            prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None
 | 
			
		||||
                # Draw track trails
 | 
			
		||||
                if self.draw_tracks:
 | 
			
		||||
                    self.annotator.draw_centroid_and_tracks(
 | 
			
		||||
                        track_line, color=self.track_color, track_thickness=self.track_thickness
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
            # Count objects
 | 
			
		||||
            if len(self.reg_pts) == 4:
 | 
			
		||||
                if (
 | 
			
		||||
                    prev_position is not None
 | 
			
		||||
                    and self.counting_region.contains(Point(track_line[-1]))
 | 
			
		||||
                    and track_id not in self.counting_list
 | 
			
		||||
                ):
 | 
			
		||||
                    self.counting_list.append(track_id)
 | 
			
		||||
                    if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
 | 
			
		||||
                        self.in_counts += 1
 | 
			
		||||
                    else:
 | 
			
		||||
                        self.out_counts += 1
 | 
			
		||||
                prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None
 | 
			
		||||
                centroid = Point((box[:2] + box[2:]) / 2)
 | 
			
		||||
 | 
			
		||||
            elif len(self.reg_pts) == 2:
 | 
			
		||||
                if prev_position is not None:
 | 
			
		||||
                    distance = Point(track_line[-1]).distance(self.counting_region)
 | 
			
		||||
                    if distance < self.line_dist_thresh and track_id not in self.counting_list:
 | 
			
		||||
                        self.counting_list.append(track_id)
 | 
			
		||||
                        if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
 | 
			
		||||
                # Count objects
 | 
			
		||||
                if len(self.reg_pts) >= 3:  # any polygon
 | 
			
		||||
                    is_inside = self.counting_region.contains(centroid)
 | 
			
		||||
                    current_position = "in" if is_inside else "out"
 | 
			
		||||
 | 
			
		||||
                    if prev_position is not None:
 | 
			
		||||
                        if self.counting_dict[track_id] != current_position and is_inside:
 | 
			
		||||
                            self.in_counts += 1
 | 
			
		||||
                        else:
 | 
			
		||||
                            self.counting_dict[track_id] = "in"
 | 
			
		||||
                        elif self.counting_dict[track_id] != current_position and not is_inside:
 | 
			
		||||
                            self.out_counts += 1
 | 
			
		||||
                            self.counting_dict[track_id] = "out"
 | 
			
		||||
                        else:
 | 
			
		||||
                            self.counting_dict[track_id] = current_position
 | 
			
		||||
 | 
			
		||||
                    else:
 | 
			
		||||
                        self.counting_dict[track_id] = current_position
 | 
			
		||||
 | 
			
		||||
                elif len(self.reg_pts) == 2:
 | 
			
		||||
                    if prev_position is not None:
 | 
			
		||||
                        is_inside = (box[0] - prev_position[0]) * (
 | 
			
		||||
                            self.counting_region.centroid.x - prev_position[0]
 | 
			
		||||
                        ) > 0
 | 
			
		||||
                        current_position = "in" if is_inside else "out"
 | 
			
		||||
 | 
			
		||||
                        if self.counting_dict[track_id] != current_position and is_inside:
 | 
			
		||||
                            self.in_counts += 1
 | 
			
		||||
                            self.counting_dict[track_id] = "in"
 | 
			
		||||
                        elif self.counting_dict[track_id] != current_position and not is_inside:
 | 
			
		||||
                            self.out_counts += 1
 | 
			
		||||
                            self.counting_dict[track_id] = "out"
 | 
			
		||||
                        else:
 | 
			
		||||
                            self.counting_dict[track_id] = current_position
 | 
			
		||||
                    else:
 | 
			
		||||
                        self.counting_dict[track_id] = None
 | 
			
		||||
 | 
			
		||||
        incount_label = f"In Count : {self.in_counts}"
 | 
			
		||||
        outcount_label = f"OutCount : {self.out_counts}"
 | 
			
		||||
@ -233,12 +249,11 @@ class ObjectCounter:
 | 
			
		||||
    def display_frames(self):
 | 
			
		||||
        """Display frame."""
 | 
			
		||||
        if self.env_check:
 | 
			
		||||
            cv2.namedWindow("Ultralytics YOLOv8 Object Counter")
 | 
			
		||||
            self.annotator.draw_region(reg_pts=self.reg_pts, color=self.region_color, thickness=self.region_thickness)
 | 
			
		||||
            cv2.namedWindow(self.window_name)
 | 
			
		||||
            if len(self.reg_pts) == 4:  # only add mouse event If user drawn region
 | 
			
		||||
                cv2.setMouseCallback(
 | 
			
		||||
                    "Ultralytics YOLOv8 Object Counter", self.mouse_event_for_region, {"region_points": self.reg_pts}
 | 
			
		||||
                )
 | 
			
		||||
            cv2.imshow("Ultralytics YOLOv8 Object Counter", self.im0)
 | 
			
		||||
                cv2.setMouseCallback(self.window_name, self.mouse_event_for_region, {"region_points": self.reg_pts})
 | 
			
		||||
            cv2.imshow(self.window_name, self.im0)
 | 
			
		||||
            # Break Window
 | 
			
		||||
            if cv2.waitKey(1) & 0xFF == ord("q"):
 | 
			
		||||
                return
 | 
			
		||||
@ -252,12 +267,7 @@ class ObjectCounter:
 | 
			
		||||
            tracks (list): List of tracks obtained from the object tracking process.
 | 
			
		||||
        """
 | 
			
		||||
        self.im0 = im0  # store image
 | 
			
		||||
 | 
			
		||||
        if tracks[0].boxes.id is None:
 | 
			
		||||
            if self.view_img:
 | 
			
		||||
                self.display_frames()
 | 
			
		||||
            return im0
 | 
			
		||||
        self.extract_and_process_tracks(tracks)
 | 
			
		||||
        self.extract_and_process_tracks(tracks)  # draw region even if no objects
 | 
			
		||||
 | 
			
		||||
        if self.view_img:
 | 
			
		||||
            self.display_frames()
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user