mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-06-12 11:34:21 +08:00
Fix FastSAM canvas drawing bug (#4705)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
187b504d68
commit
2ba80e355a
@ -44,15 +44,15 @@ from ultralytics.utils.files import increment_path
|
|||||||
from ultralytics.utils.torch_utils import select_device, smart_inference_mode
|
from ultralytics.utils.torch_utils import select_device, smart_inference_mode
|
||||||
|
|
||||||
STREAM_WARNING = """
|
STREAM_WARNING = """
|
||||||
WARNING ⚠️ stream/video/webcam/dir predict source will accumulate results in RAM unless `stream=True` is passed,
|
WARNING ⚠️ inference results will accumulate in RAM unless `stream=True` is passed, causing potential out-of-memory
|
||||||
causing potential out-of-memory errors for large sources or long-running streams/videos.
|
errors for large sources or long-running streams and videos. See https://docs.ultralytics.com/modes/predict/ for help.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
results = model(source=..., stream=True) # generator of Results objects
|
results = model(source=..., stream=True) # generator of Results objects
|
||||||
for r in results:
|
for r in results:
|
||||||
boxes = r.boxes # Boxes object for bbox outputs
|
boxes = r.boxes # Boxes object for bbox outputs
|
||||||
masks = r.masks # Masks object for segment masks outputs
|
masks = r.masks # Masks object for segment masks outputs
|
||||||
probs = r.probs # Class probabilities for classification outputs
|
probs = r.probs # Class probabilities for classification outputs
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@ -68,7 +68,6 @@ class FastSAMPrompt:
|
|||||||
if len(contours) > 1:
|
if len(contours) > 1:
|
||||||
for b in contours:
|
for b in contours:
|
||||||
x_t, y_t, w_t, h_t = cv2.boundingRect(b)
|
x_t, y_t, w_t, h_t = cv2.boundingRect(b)
|
||||||
# 将多个bbox合并成一个
|
|
||||||
x1 = min(x1, x_t)
|
x1 = min(x1, x_t)
|
||||||
y1 = min(y1, y_t)
|
y1 = min(y1, y_t)
|
||||||
x2 = max(x2, x_t + w_t)
|
x2 = max(x2, x_t + w_t)
|
||||||
@ -84,9 +83,8 @@ class FastSAMPrompt:
|
|||||||
mask_random_color=True,
|
mask_random_color=True,
|
||||||
better_quality=True,
|
better_quality=True,
|
||||||
retina=False,
|
retina=False,
|
||||||
withContours=True):
|
with_contours=True):
|
||||||
n = len(annotations)
|
pbar = TQDM(annotations, total=len(annotations))
|
||||||
pbar = TQDM(annotations, total=n)
|
|
||||||
for ann in pbar:
|
for ann in pbar:
|
||||||
result_name = os.path.basename(ann.path)
|
result_name = os.path.basename(ann.path)
|
||||||
image = ann.orig_img
|
image = ann.orig_img
|
||||||
@ -122,17 +120,13 @@ class FastSAMPrompt:
|
|||||||
target_width=original_w,
|
target_width=original_w,
|
||||||
)
|
)
|
||||||
|
|
||||||
if withContours:
|
if with_contours:
|
||||||
contour_all = []
|
contour_all = []
|
||||||
temp = np.zeros((original_h, original_w, 1))
|
temp = np.zeros((original_h, original_w, 1))
|
||||||
for i, mask in enumerate(masks):
|
for i, mask in enumerate(masks):
|
||||||
mask = mask.astype(np.uint8)
|
mask = mask.astype(np.uint8)
|
||||||
if not retina:
|
if not retina:
|
||||||
mask = cv2.resize(
|
mask = cv2.resize(mask, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
|
||||||
mask,
|
|
||||||
(original_w, original_h),
|
|
||||||
interpolation=cv2.INTER_NEAREST,
|
|
||||||
)
|
|
||||||
contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
||||||
contour_all.extend(iter(contours))
|
contour_all.extend(iter(contours))
|
||||||
cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
|
cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
|
||||||
@ -143,17 +137,14 @@ class FastSAMPrompt:
|
|||||||
plt.axis('off')
|
plt.axis('off')
|
||||||
fig = plt.gcf()
|
fig = plt.gcf()
|
||||||
|
|
||||||
try:
|
# Check if the canvas has been drawn
|
||||||
buf = fig.canvas.tostring_rgb()
|
if fig.canvas.get_renderer() is None: # macOS requires this or tests fail
|
||||||
except AttributeError:
|
|
||||||
fig.canvas.draw()
|
fig.canvas.draw()
|
||||||
buf = fig.canvas.tostring_rgb()
|
|
||||||
cols, rows = fig.canvas.get_width_height()
|
|
||||||
img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
|
|
||||||
|
|
||||||
save_path = Path(output) / result_name
|
save_path = Path(output) / result_name
|
||||||
save_path.parent.mkdir(exist_ok=True, parents=True)
|
save_path.parent.mkdir(exist_ok=True, parents=True)
|
||||||
cv2.imwrite(str(save_path), img_array)
|
image = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
|
||||||
|
image.save(save_path)
|
||||||
plt.close()
|
plt.close()
|
||||||
pbar.set_description(f'Saving {result_name} to {save_path}')
|
pbar.set_description(f'Saving {result_name} to {save_path}')
|
||||||
|
|
||||||
@ -289,7 +280,7 @@ class FastSAMPrompt:
|
|||||||
if h != target_height or w != target_width:
|
if h != target_height or w != target_width:
|
||||||
points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
|
points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
|
||||||
onemask = np.zeros((h, w))
|
onemask = np.zeros((h, w))
|
||||||
for i, annotation in enumerate(masks):
|
for annotation in masks:
|
||||||
mask = annotation['segmentation'] if isinstance(annotation, dict) else annotation
|
mask = annotation['segmentation'] if isinstance(annotation, dict) else annotation
|
||||||
for i, point in enumerate(points):
|
for i, point in enumerate(points):
|
||||||
if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
|
if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
|
||||||
|
@ -263,7 +263,7 @@ class ProfileModels:
|
|||||||
data = clipped_data
|
data = clipped_data
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-7):
|
def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-3):
|
||||||
if not self.trt or not Path(engine_file).is_file():
|
if not self.trt or not Path(engine_file).is_file():
|
||||||
return 0.0, 0.0
|
return 0.0, 0.0
|
||||||
|
|
||||||
@ -291,7 +291,7 @@ class ProfileModels:
|
|||||||
run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3) # sigma clipping
|
run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3) # sigma clipping
|
||||||
return np.mean(run_times), np.std(run_times)
|
return np.mean(run_times), np.std(run_times)
|
||||||
|
|
||||||
def profile_onnx_model(self, onnx_file: str, eps: float = 1e-7):
|
def profile_onnx_model(self, onnx_file: str, eps: float = 1e-3):
|
||||||
check_requirements('onnxruntime')
|
check_requirements('onnxruntime')
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user