diff --git a/peoplenet.py b/peoplenet.py
index 727f588a..ea008a42 100644
--- a/peoplenet.py
+++ b/peoplenet.py
@@ -85,27 +85,44 @@ def postprocess(
     cov: np.ndarray,
     bbox: np.ndarray,
     original_dims: Tuple[int, int],
-    confidence_threshold: float = 0.1,
-    min_distance: int = 10,
-    min_size: int = 5
+    confidence_thresholds: Dict[str, float] = {
+        'Person': 0.3,  # Lower threshold to catch more people
+        'Face': 0.5,
+        'Bag': 0.8  # Higher threshold to reduce false positives
+    },
+    min_distance: Dict[str, int] = {
+        'Person': 100,  # Larger distance for person detection
+        'Face': 30,
+        'Bag': 50
+    },
+    min_size: Dict[str, int] = {
+        'Person': 20,  # Larger minimum size for person detection
+        'Face': 10,
+        'Bag': 15
+    },
+    box_scale_factor: Dict[str, float] = {
+        'Person': 1.3,  # Larger scaling for person boxes
+        'Face': 1.1,
+        'Bag': 1.0
+    }
 ) -> List[Tuple[str, Tuple[float, float, float, float], float]]:
     """
-    Enhanced postprocessing using heatmap-based detection and region growing.
-
-    Args:
-        cov (np.ndarray): Coverage array of shape (1, 3, 34, 60)
-        bbox (np.ndarray): Bounding box array of shape (1, 12, 34, 60)
-        original_dims (Tuple[int, int]): Original image dimensions (width, height)
-        confidence_threshold (float): Confidence threshold for filtering detections
-        min_distance (int): Minimum distance between peaks
-        min_size (int): Minimum size for valid detections
+    Enhanced postprocessing with better person detection and overlap handling.
     """
-    classes = ['Bag', 'Face', 'Person']
+    classes = ['Person', 'Face', 'Bag']  # Reorder to prioritize person detection
    results = []
     orig_height, orig_width = original_dims[1], original_dims[0]
 
-    for class_idx, class_name in enumerate(classes):
+    for class_name in classes:
+        class_idx = ['Bag', 'Face', 'Person'].index(class_name)  # Map to original model output order
+
+        # Get class-specific parameters
+        threshold = confidence_thresholds[class_name]
+        distance = min_distance[class_name]
+        size = min_size[class_name]
+        scale = box_scale_factor[class_name]
+
         # Extract heatmap for current class
         heatmap = cov[0, class_idx]
 
@@ -116,11 +133,18 @@ def postprocess(
         if heatmap.max() > heatmap.min():
             heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
 
+        # For person class, apply additional processing
+        if class_name == 'Person':
+            # Enhance person detection sensitivity
+            heatmap = cv2.GaussianBlur(heatmap, (5, 5), 0)
+            heatmap = np.power(heatmap, 0.7)  # Reduce sensitivity to confidence threshold
+
         # Find local maxima
         coordinates = peak_local_max(
             heatmap,
-            min_distance=min_distance,
-            threshold_abs=confidence_threshold
+            min_distance=distance,
+            threshold_abs=threshold,
+            exclude_border=False
         )
 
         # Process each peak
@@ -128,7 +152,7 @@ def postprocess(
             y, x = coord
 
             # Grow region around peak
-            binary = heatmap > (heatmap[y, x] * 0.4)  # 40% of peak value
+            binary = heatmap > (heatmap[y, x] * 0.3)  # More lenient region growing
             labeled, _ = ndimage.label(binary)
             region = labeled == labeled[y, x]
 
@@ -138,29 +162,135 @@ def postprocess(
             x1, x2 = np.min(xs), np.max(xs)
             y1, y2 = np.min(ys), np.max(ys)
 
-            # Filter small detections
-            if (x2 - x1 >= min_size) and (y2 - y1 >= min_size):
-                # Convert to center format
-                center_x = (x1 + x2) / 2
-                center_y = (y1 + y2) / 2
-                width = x2 - x1
-                height = y2 - y1
+            # Scale the box
+            width = x2 - x1
+            height = y2 - y1
+            center_x = (x1 + x2) / 2
+            center_y = (y1 + y2) / 2
 
+            # Apply class-specific scaling
+            width *= scale
+            height *= scale
+
+            # Ensure box stays within image bounds
+            width = min(width, orig_width - center_x, center_x)
+            height = min(height, orig_height - center_y, center_y)
+
+            # Filter small detections
+            if width >= size and height >= size:
                 # Get confidence from peak value
                 confidence = heatmap[y, x]
 
+                # Add detection
                 results.append((
                     class_name,
                     (center_x, center_y, width, height),
                     confidence
                 ))
 
-    # Merge overlapping boxes
-    results = merge_overlapping_detections(results)
+    # Resolve overlapping detections
+    results = resolve_overlapping_detections(results)
 
     return results
 
 
+def resolve_overlapping_detections(
+    detections: List[Tuple[str, Tuple[float, float, float, float], float]],
+    iou_threshold: float = 0.3
+) -> List[Tuple[str, Tuple[float, float, float, float], float]]:
+    """
+    Resolve overlapping detections with class priority rules.
+    """
+    if not detections:
+        return []
+
+    # Sort by class priority (Person > Face > Bag) and confidence
+    class_priority = {'Person': 0, 'Face': 1, 'Bag': 2}
+    detections = sorted(detections,
+                        key=lambda x: (class_priority[x[0]], -x[2]))
+
+    final_detections = []
+
+    while detections:
+        current = detections.pop(0)
+        current_box = current[1]  # (x, y, w, h)
+
+        # Check overlap with existing final detections
+        overlapping = False
+        for existing in final_detections:
+            existing_box = existing[1]
+            if calculate_iou(current_box, existing_box) > iou_threshold:
+                overlapping = True
+                break
+
+        if not overlapping:
+            final_detections.append(current)
+
+    return final_detections
+
+
+def calculate_iou(box1: Tuple[float, float, float, float],
+                  box2: Tuple[float, float, float, float]) -> float:
+    """
+    Calculate IoU between two boxes in center format (x, y, w, h).
+    """
+    # Convert to corner format
+    x1_1, y1_1 = box1[0] - box1[2] / 2, box1[1] - box1[3] / 2
+    x2_1, y2_1 = box1[0] + box1[2] / 2, box1[1] + box1[3] / 2
+    x1_2, y1_2 = box2[0] - box2[2] / 2, box2[1] - box2[3] / 2
+    x2_2, y2_2 = box2[0] + box2[2] / 2, box2[1] + box2[3] / 2
+
+    # Calculate intersection
+    x1_i = max(x1_1, x1_2)
+    y1_i = max(y1_1, y1_2)
+    x2_i = min(x2_1, x2_2)
+    y2_i = min(y2_1, y2_2)
+
+    if x2_i <= x1_i or y2_i <= y1_i:
+        return 0.0
+
+    intersection = (x2_i - x1_i) * (y2_i - y1_i)
+
+    # Calculate areas
+    area1 = box1[2] * box1[3]
+    area2 = box2[2] * box2[3]
+
+    # Calculate IoU
+    return intersection / (area1 + area2 - intersection)
+
+
+def apply_class_rules(
+    detections: List[Tuple[str, Tuple[float, float, float, float], float]],
+    image_dims: Tuple[int, int]
+) -> List[Tuple[str, Tuple[float, float, float, float], float]]:
+    """
+    Apply class-specific rules to improve detection accuracy.
+    """
+    filtered_detections = []
+
+    # Group detections by location for conflict resolution
+    location_groups = {}
+    for detection in detections:
+        class_name, (x, y, w, h), conf = detection
+        key = f"{int(x / 50)},{int(y / 50)}"  # Group by grid cells
+        if key not in location_groups:
+            location_groups[key] = []
+        location_groups[key].append(detection)
+
+    # Process each location group
+    for group in location_groups.values():
+        if len(group) > 1:
+            # If multiple detections in same area, prefer Person/Face over Bag
+            person_detections = [d for d in group if d[0] in ['Person', 'Face']]
+            if person_detections:
+                filtered_detections.extend(person_detections)
+                continue
+
+        filtered_detections.extend(group)
+
+    return filtered_detections
+
+
 def merge_overlapping_detections(
     detections: List[Tuple[str, Tuple[float, float, float, float], float]],
     iou_threshold: float = 0.5
@@ -469,7 +599,11 @@ def process_image(image_path: str, triton_url: str, model_name: str) -> Tuple[np
     output_bbox, output_cov, inference_time = client.run_inference(preprocessed)
 
     # Post-process detections
-    detections = postprocess(output_cov, output_bbox, (image.shape[1], image.shape[0]))
+    detections = postprocess(
+        output_cov,
+        output_bbox,
+        (image.shape[1], image.shape[0])
+    )
 
     # Draw detections
     result_image = draw_detections(image, detections)