Mirror of https://github.com/THU-MIG/yolov10.git, synced 2025-05-23 13:34:23 +08:00
peoplenet working
This commit is contained in:
parent e9d9e33559
commit 72259392ed
peoplenet.py (new file, 203 lines)
@@ -0,0 +1,203 @@
import cv2
import numpy as np
from typing import List, Tuple
from sklearn.cluster import DBSCAN
import tritonclient.http as httpclient
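
# Dependencies (a likely install set, inferred from the imports above):
#   pip install opencv-python numpy scikit-learn tritonclient[http]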


def read_image(image_path: str) -> np.ndarray:
    """
    Read an image using OpenCV.

    Args:
        image_path (str): Path to the image file

    Returns:
        np.ndarray: Image array in BGR format
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread returns None on failure; fail loudly instead
        raise FileNotFoundError(f"Could not read image: {image_path}")
    return image


def preprocess(image: np.ndarray) -> np.ndarray:
    """
    Preprocess the input image for the PeopleNet model.

    Args:
        image (np.ndarray): Input image array in BGR format

    Returns:
        np.ndarray: Preprocessed image array of shape (1, 3, 544, 960)
    """
    # Convert BGR to RGB
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Resize the image to 960x544 (width x height)
    image = cv2.resize(image, (960, 544))

    # Scale pixel values to [0, 1]
    image = image.astype(np.float32) / 255.0

    # Transpose from (H, W, C) to (C, H, W)
    image = image.transpose(2, 0, 1)

    # Add batch dimension
    image = np.expand_dims(image, axis=0)

    return image
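
# Illustrative check (not part of the pipeline): the tensor fed to Triton
# should be float32 of shape (1, 3, 544, 960). "sample.jpg" is a hypothetical
# path used only for this sketch.
#
#   tensor = preprocess(read_image("sample.jpg"))
#   assert tensor.shape == (1, 3, 544, 960) and tensor.dtype == np.float32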


def run_inference(
    triton_client: httpclient.InferenceServerClient,
    preprocessed_image: np.ndarray,
    model_name: str,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Run inference using Triton Inference Server.

    Args:
        triton_client (httpclient.InferenceServerClient): Triton client object
        preprocessed_image (np.ndarray): Preprocessed image array
        model_name (str): Name of the model on the Triton server

    Returns:
        Tuple[np.ndarray, np.ndarray]: Coverage and bounding box output tensors
    """
    # Prepare the input data
    input_name = "input_1:0"  # Adjust if your model uses a different input name
    inputs = [httpclient.InferInput(input_name, preprocessed_image.shape, datatype="FP32")]
    inputs[0].set_data_from_numpy(preprocessed_image)

    # Run inference
    outputs = [
        httpclient.InferRequestedOutput("output_cov/Sigmoid:0"),  # Adjust if your model uses different output names
        httpclient.InferRequestedOutput("output_bbox/BiasAdd:0"),
    ]
    response = triton_client.infer(model_name, inputs, outputs=outputs)

    # Get the output data
    cov = response.as_numpy("output_cov/Sigmoid:0")
    bbox = response.as_numpy("output_bbox/BiasAdd:0")

    return cov, bbox
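
# Optional sanity check before running inference (a minimal sketch using the
# standard tritonclient HTTP API; URL and model name match the usage example
# at the bottom of this file):
#
#   client = httpclient.InferenceServerClient(url="192.168.0.22:8000")
#   if not client.is_server_live() or not client.is_model_ready("peoplenet"):
#       raise RuntimeError("Triton server or model is not ready")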


def postprocess(
    cov: np.ndarray,
    bbox: np.ndarray,
    confidence_threshold: float = 0.5,
    eps: float = 0.2,
    min_samples: int = 1,
) -> List[Tuple[str, Tuple[float, float, float, float], float]]:
    """
    Postprocess the model output to get final detections.

    Args:
        cov (np.ndarray): Coverage array of shape (1, 3, 34, 60)
        bbox (np.ndarray): Bounding box array of shape (1, 12, 34, 60)
        confidence_threshold (float): Confidence threshold for filtering detections
        eps (float): DBSCAN epsilon parameter
        min_samples (int): DBSCAN min_samples parameter

    Returns:
        List[Tuple[str, Tuple[float, float, float, float], float]]:
            List of (class_name, (x, y, w, h), confidence) tuples
    """
    classes = ['Bag', 'Face', 'Person']
    results = []

    for class_idx, class_name in enumerate(classes):
        # Extract class-specific arrays: one coverage channel and four bbox channels per class
        class_cov = cov[0, class_idx]
        class_bbox = bbox[0, class_idx * 4:(class_idx + 1) * 4]

        # Filter by confidence
        mask = class_cov > confidence_threshold
        confident_cov = class_cov[mask]
        confident_bbox = class_bbox[:, mask]

        if confident_cov.size == 0:
            continue

        # Prepare data for clustering (grid size derived from the coverage map, 34x60 by default)
        grid_h, grid_w = class_cov.shape
        grid_y, grid_x = np.mgrid[0:grid_h, 0:grid_w]
        grid_points = np.column_stack((grid_x[mask], grid_y[mask]))

        # Cluster nearby grid cells into single detections
        clustering = DBSCAN(eps=eps, min_samples=min_samples).fit(grid_points)
        labels = clustering.labels_

        for cluster_id in range(labels.max() + 1):
            cluster_mask = labels == cluster_id
            cluster_cov = confident_cov[cluster_mask]
            cluster_bbox = confident_bbox[:, cluster_mask]

            # Compute coverage-weighted average of bounding box coordinates
            weights = cluster_cov / cluster_cov.sum()
            avg_bbox = (cluster_bbox * weights).sum(axis=1)

            # Convert from (x1, y1, x2, y2) corners to (center x, center y, w, h)
            x1, y1, x2, y2 = avg_bbox.tolist()
            x, y = (x1 + x2) / 2, (y1 + y2) / 2
            w, h = x2 - x1, y2 - y1

            # Report the cluster's peak coverage as its confidence
            confidence = cluster_cov.max().item()
            results.append((class_name, (x, y, w, h), confidence))

    return results
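
# Note: how (x, y, w, h) relates to input-image pixels depends on how the
# exported model encodes its bbox regression. The shapes above imply a 16x
# stride (960 / 60 == 544 / 34 == 16), so if the coordinates come out in grid
# units, a hypothetical rescale before drawing would be:
#
#   stride = 16
#   x, y, w, h = x * stride, y * stride, w * stride, h * stride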


def process_image(
    triton_url: str,
    model_name: str,
    image_path: str,
) -> List[Tuple[str, Tuple[float, float, float, float], float]]:
    """
    Process an image through the PeopleNet model using Triton Inference Server.

    Args:
        triton_url (str): URL of the Triton Inference Server
        model_name (str): Name of the model on the Triton server
        image_path (str): Path to the input image file

    Returns:
        List[Tuple[str, Tuple[float, float, float, float], float]]:
            List of (class_name, (x, y, w, h), confidence) tuples
    """
    # Create Triton client
    triton_client = httpclient.InferenceServerClient(url=triton_url)

    # Read the image
    image = read_image(image_path)

    # Preprocess
    preprocessed_image = preprocess(image)

    # Run inference
    cov, bbox = run_inference(triton_client, preprocessed_image, model_name)

    # Postprocess
    detections = postprocess(cov, bbox)

    return detections


# Usage example
if __name__ == "__main__":
    triton_url = "192.168.0.22:8000"  # Adjust this to your Triton server's address
    model_name = "peoplenet"  # Adjust this to your model's name on Triton
    image_path = "ultralytics/assets/83.jpg"

    results = process_image(triton_url, model_name, image_path)

    # Print results
    for class_name, (x, y, w, h), confidence in results:
        print(f"Class: {class_name}, Bbox: ({x:.2f}, {y:.2f}, {w:.2f}, {h:.2f}), Confidence: {confidence:.2f}")

    # Optionally, visualize results on the image
    image = cv2.imread(image_path)
    for class_name, (x, y, w, h), confidence in results:
        x1, y1 = int(x - w / 2), int(y - h / 2)
        x2, y2 = int(x + w / 2), int(y + h / 2)
        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(image, f"{class_name}: {confidence:.2f}", (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

    cv2.imshow("PeopleNet Detections", image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
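
# On a headless machine (no display), cv2.imshow will fail; a minimal
# alternative is to write the annotated image to disk instead
# ("detections.jpg" is a hypothetical output path):
#
#   cv2.imwrite("detections.jpg", image)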