Support FastSAM directory inference and plot (#4634)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Gezhi Zhang 2023-08-30 20:19:36 +08:00 committed by GitHub
parent 620335de27
commit 8596ee241f
2 changed files with 136 additions and 120 deletions

ultralytics/models/fastsam/predict.py

@@ -22,7 +22,7 @@ class FastSAMPredictor(DetectionPredictor):
                                         max_det=self.args.max_det,
                                         nc=len(self.model.names),
                                         classes=self.args.classes)
-        full_box = torch.zeros_like(p[0][0])
+        full_box = torch.zeros(p[0].shape[1])
         full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
         full_box = full_box.view(1, -1)
         critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
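
Context on the one-line change above: torch.zeros_like(p[0][0]) indexes the first prediction row, which raises an IndexError when an image yields no detections, while torch.zeros(p[0].shape[1]) only needs the column count. A minimal sketch (the 38-column width, 6 box fields plus 32 mask coefficients, is an illustrative assumption, not taken from this diff):

    import torch

    p0 = torch.zeros((0, 38))              # prediction tensor for an image with zero detections
    # full_box = torch.zeros_like(p0[0])   # old: IndexError on the empty dim-0
    full_box = torch.zeros(p0.shape[1])    # new: works regardless of detection count
    print(full_box.shape)                  # torch.Size([38])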

ultralytics/models/fastsam/prompt.py

@@ -8,18 +8,17 @@ import matplotlib.pyplot as plt
 import numpy as np
 import torch
 from PIL import Image
+from tqdm import tqdm
 
-from ultralytics.utils import LOGGER
+from ultralytics.utils import TQDM_BAR_FORMAT
 
 
 class FastSAMPrompt:
 
-    def __init__(self, img_path, results, device='cuda') -> None:
-        # self.img_path = img_path
+    def __init__(self, source, results, device='cuda') -> None:
         self.device = device
         self.results = results
-        self.img_path = str(img_path)
-        self.ori_img = cv2.imread(self.img_path)
+        self.source = source
 
         # Import and assign clip
         try:
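
The constructor change above is the heart of directory support: instead of eagerly reading a single image with cv2.imread, the raw source is stored and per-image data is taken from each Results object later. A minimal sketch of the new call shape ('images/' and the weights filename are placeholders, not values from this commit):

    from ultralytics import FastSAM
    from ultralytics.models.fastsam import FastSAMPrompt

    model = FastSAM('FastSAM-s.pt')                  # placeholder weights file
    results = model('images/', retina_masks=True)    # a directory now works as a source
    prompt = FastSAMPrompt('images/', results, device='cpu')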
@@ -48,7 +47,7 @@ class FastSAMPrompt:
     @staticmethod
     def _format_results(result, filter=0):
         annotations = []
-        n = len(result.masks.data)
+        n = len(result.masks.data) if result.masks is not None else 0
         for i in range(n):
             mask = result.masks.data[i] == 1.0
             if torch.sum(mask) >= filter:
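
The guard above covers images where inference found nothing, in which case result.masks is None. A quick sketch of the behavior it enables (assuming results from a prior model call):

    annotations = FastSAMPrompt._format_results(results[0], filter=0)
    print(len(annotations))    # 0 for an empty image, instead of an AttributeError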
@@ -86,32 +85,34 @@ class FastSAMPrompt:
              mask_random_color=True,
              better_quality=True,
              retina=False,
-             with_countouers=True):
-        if isinstance(annotations[0], dict):
-            annotations = [annotation['segmentation'] for annotation in annotations]
-        if isinstance(annotations, torch.Tensor):
-            annotations = annotations.cpu().numpy()
-        result_name = os.path.basename(self.img_path)
-        image = self.ori_img
-        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-        original_h = image.shape[0]
-        original_w = image.shape[1]
+             withContours=True):
+        n = len(annotations)
+        pbar = tqdm(annotations, total=n, bar_format=TQDM_BAR_FORMAT)
+        for ann in pbar:
+            result_name = os.path.basename(ann.path)
+            image = ann.orig_img
+            original_h, original_w = ann.orig_shape
             # for macOS only
             # plt.switch_backend('TkAgg')
-        fig = plt.figure(figsize=(original_w / 100, original_h / 100))
+            plt.figure(figsize=(original_w / 100, original_h / 100))
             # Add subplot with no margin.
             plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
             plt.margins(0, 0)
             plt.gca().xaxis.set_major_locator(plt.NullLocator())
             plt.gca().yaxis.set_major_locator(plt.NullLocator())
             plt.imshow(image)
+            if ann.masks is not None:
+                masks = ann.masks.data
                 if better_quality:
-            for i, mask in enumerate(annotations):
+                    if isinstance(masks[0], torch.Tensor):
+                        masks = np.array(masks.cpu())
+                    for i, mask in enumerate(masks):
                         mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
-                annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
+                        masks[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
                 self.fast_show_mask(
-            annotations,
+                    masks,
                     plt.gca(),
                     random_color=mask_random_color,
                     bbox=bbox,
@@ -122,33 +123,41 @@ class FastSAMPrompt:
                     target_width=original_w,
                 )
-        if with_countouers:
+            if withContours:
                 contour_all = []
                 temp = np.zeros((original_h, original_w, 1))
-            for i, mask in enumerate(annotations):
-                if isinstance(mask, dict):
-                    mask = mask['segmentation']
-                annotation = mask.astype(np.uint8)
+                for i, mask in enumerate(masks):
+                    mask = mask.astype(np.uint8)
                     if not retina:
-                annotation = cv2.resize(
-                    annotation,
+                        mask = cv2.resize(
+                            mask,
                             (original_w, original_h),
                             interpolation=cv2.INTER_NEAREST,
                         )
-                contours, hierarchy = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+                    contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
                     contour_all.extend(iter(contours))
                 cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
                 color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
                 contour_mask = temp / 255 * color.reshape(1, 1, -1)
                 plt.imshow(contour_mask)
 
+            plt.axis('off')
+            fig = plt.gcf()
+            try:
+                buf = fig.canvas.tostring_rgb()
+            except AttributeError:
+                fig.canvas.draw()
+                buf = fig.canvas.tostring_rgb()
+            cols, rows = fig.canvas.get_width_height()
+            img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
+
             save_path = Path(output) / result_name
             save_path.parent.mkdir(exist_ok=True, parents=True)
-        plt.axis('off')
-        fig.savefig(save_path)
-        LOGGER.info(f'Saved to {save_path.absolute()}')
+            cv2.imwrite(str(save_path), img_array)
+            plt.close()
+            pbar.set_description('Saving {} to {}'.format(result_name, save_path))
 
-    # CPU post process
     @staticmethod
     def fast_show_mask(
         annotation,
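
The save path above replaces fig.savefig with a manual rasterization of the Matplotlib canvas. A standalone sketch of that canvas-to-array trick (tostring_rgb is the call the commit uses; newer Matplotlib releases deprecate it in favour of buffer_rgba):

    import matplotlib
    matplotlib.use('Agg')                  # assumption: any non-interactive backend
    import matplotlib.pyplot as plt
    import numpy as np

    fig = plt.figure(figsize=(2, 2))
    plt.plot([0, 1], [0, 1])
    fig.canvas.draw()                      # make sure the canvas is rendered
    buf = fig.canvas.tostring_rgb()        # raw RGB bytes of the rendered figure
    cols, rows = fig.canvas.get_width_height()
    img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
    print(img_array.shape)                 # (height, width, 3)
    plt.close(fig)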
@@ -215,8 +224,9 @@ class FastSAMPrompt:
         return probs[:, 0].softmax(dim=0)
 
     def _crop_image(self, format_results):
-        image = Image.fromarray(cv2.cvtColor(self.ori_img, cv2.COLOR_BGR2RGB))
+        if os.path.isdir(self.source):
+            raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
+        image = Image.fromarray(cv2.cvtColor(self.results[0].orig_img, cv2.COLOR_BGR2RGB))
         ori_w, ori_h = image.size
         annotations = format_results
         mask_h, mask_w = annotations[0]['segmentation'].shape
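
text_prompt depends on _crop_image, which needs exactly one original image, hence the explicit directory rejection above. A sketch of the resulting behavior ('images/' is a placeholder directory):

    prompt = FastSAMPrompt('images/', results, device='cpu')
    try:
        prompt.text_prompt('a photo of a dog')   # calls _crop_image internally
    except ValueError as err:
        print(err)    # 'images/' is a directory, not a valid source for this function.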
@@ -237,11 +247,12 @@ class FastSAMPrompt:
         return cropped_boxes, cropped_images, not_crop, filter_id, annotations
 
     def box_prompt(self, bbox):
+        if self.results[0].masks is not None:
             assert (bbox[2] != 0 and bbox[3] != 0)
+            if os.path.isdir(self.source):
+                raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
             masks = self.results[0].masks.data
-        target_height = self.ori_img.shape[0]
-        target_width = self.ori_img.shape[1]
+            target_height, target_width = self.results[0].orig_shape
             h = masks.shape[1]
             w = masks.shape[2]
             if h != target_height or w != target_width:
@@ -265,13 +276,15 @@ class FastSAMPrompt:
             IoUs = masks_area / union
             max_iou_index = torch.argmax(IoUs)
 
-        return np.array([masks[max_iou_index].cpu().numpy()])
+            self.results[0].masks.data = torch.tensor(np.array([masks[max_iou_index].cpu().numpy()]))
+        return self.results
 
     def point_prompt(self, points, pointlabel):  # numpy processing
+        if self.results[0].masks is not None:
+            if os.path.isdir(self.source):
+                raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
             masks = self._format_results(self.results[0], 0)
-        target_height = self.ori_img.shape[0]
-        target_width = self.ori_img.shape[1]
+            target_height, target_width = self.results[0].orig_shape
             h = masks[0]['segmentation'].shape[0]
             w = masks[0]['segmentation'].shape[1]
             if h != target_height or w != target_width:
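
As the hunk above shows, box_prompt now writes the best-IoU mask back into the Results object and returns the list, instead of returning a bare numpy array, so every prompt type hands plot() the same structure. A sketch with placeholder coordinates:

    ann = prompt.box_prompt(bbox=[200, 200, 300, 300])   # returns the Results list
    prompt.plot(annotations=ann, output='./output/')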
@@ -285,9 +298,11 @@ class FastSAMPrompt:
                 if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
                     onemask -= mask
             onemask = onemask >= 1
-        return np.array([onemask])
+            self.results[0].masks.data = torch.tensor(np.array([onemask]))
+        return self.results
 
     def text_prompt(self, text):
+        if self.results[0].masks is not None:
             format_results = self._format_results(self.results[0], 0)
             cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
             clip_model, preprocess = self.clip.load('ViT-B/32', device=self.device)
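
point_prompt follows the same convention, merging the point-selected mask into self.results before returning it. A sketch with a placeholder point and label:

    ann = prompt.point_prompt(points=[[620, 360]], pointlabel=[1])   # 1 = foreground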
@@ -295,7 +310,8 @@ class FastSAMPrompt:
             max_idx = scores.argsort()
             max_idx = max_idx[-1]
             max_idx += sum(np.array(filter_id) <= int(max_idx))
-        return np.array([annotations[max_idx]['segmentation']])
+            self.results[0].masks.data = torch.tensor(np.array([ann['segmentation'] for ann in annotations]))
+        return self.results
 
     def everything_prompt(self):
-        return self.results[0].masks.data
+        return self.results
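
Taken together, the changes enable the workflow in the PR title: run FastSAM over a directory, then plot and save one annotated image per result. An end-to-end sketch (paths and the weights file are placeholders, not values from this commit):

    from ultralytics import FastSAM
    from ultralytics.models.fastsam import FastSAMPrompt

    model = FastSAM('FastSAM-s.pt')
    results = model('images/', retina_masks=True, imgsz=640, conf=0.4, iou=0.9)
    prompt = FastSAMPrompt('images/', results, device='cpu')
    ann = prompt.everything_prompt()           # now returns Results, not a raw mask tensor
    prompt.plot(annotations=ann, output='./output/')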