Support FastSAM directory inference and plot (#4634)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Gezhi Zhang 2023-08-30 20:19:36 +08:00 committed by GitHub
parent 620335de27
commit 8596ee241f
2 changed files with 136 additions and 120 deletions


@@ -22,7 +22,7 @@ class FastSAMPredictor(DetectionPredictor):
                                     max_det=self.args.max_det,
                                     nc=len(self.model.names),
                                     classes=self.args.classes)
-        full_box = torch.zeros_like(p[0][0])
+        full_box = torch.zeros(p[0].shape[1])
         full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
         full_box = full_box.view(1, -1)
         critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
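
Note: one plausible motivation for building the full-image box from the column count rather than by indexing the first prediction is the zero-detection case. A minimal standalone sketch (the 38-column width is an assumption, not taken from the diff):

    import torch

    # Hypothetical post-NMS output with zero detections and 38 columns.
    p0 = torch.zeros((0, 38))

    # Old construction indexes the first prediction and fails on an empty tensor:
    # full_box = torch.zeros_like(p0[0])  # IndexError: index 0 is out of bounds for dimension 0

    # New construction only needs the column count, so it also works with no detections:
    full_box = torch.zeros(p0.shape[1])
    full_box = full_box.view(1, -1)
    print(full_box.shape)  # torch.Size([1, 38])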


@@ -8,18 +8,17 @@ import matplotlib.pyplot as plt
 import numpy as np
 import torch
 from PIL import Image
+from tqdm import tqdm

-from ultralytics.utils import LOGGER
+from ultralytics.utils import TQDM_BAR_FORMAT


 class FastSAMPrompt:

-    def __init__(self, img_path, results, device='cuda') -> None:
-        # self.img_path = img_path
+    def __init__(self, source, results, device='cuda') -> None:
         self.device = device
         self.results = results
-        self.img_path = str(img_path)
-        self.ori_img = cv2.imread(self.img_path)
+        self.source = source

         # Import and assign clip
         try:
@@ -48,7 +47,7 @@ class FastSAMPrompt:
     @staticmethod
     def _format_results(result, filter=0):
         annotations = []
-        n = len(result.masks.data)
+        n = len(result.masks.data) if result.masks is not None else 0
         for i in range(n):
             mask = result.masks.data[i] == 1.0
             if torch.sum(mask) >= filter:
@@ -86,69 +85,79 @@ class FastSAMPrompt:
              mask_random_color=True,
              better_quality=True,
              retina=False,
-             with_countouers=True):
-        if isinstance(annotations[0], dict):
-            annotations = [annotation['segmentation'] for annotation in annotations]
-        if isinstance(annotations, torch.Tensor):
-            annotations = annotations.cpu().numpy()
-        result_name = os.path.basename(self.img_path)
-        image = self.ori_img
-        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-        original_h = image.shape[0]
-        original_w = image.shape[1]
-        # for macOS only
-        # plt.switch_backend('TkAgg')
-        fig = plt.figure(figsize=(original_w / 100, original_h / 100))
-        # Add subplot with no margin.
-        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
-        plt.margins(0, 0)
-        plt.gca().xaxis.set_major_locator(plt.NullLocator())
-        plt.gca().yaxis.set_major_locator(plt.NullLocator())
-        plt.imshow(image)
-        if better_quality:
-            for i, mask in enumerate(annotations):
-                mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
-                annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
-        self.fast_show_mask(
-            annotations,
-            plt.gca(),
-            random_color=mask_random_color,
-            bbox=bbox,
-            points=points,
-            pointlabel=point_label,
-            retinamask=retina,
-            target_height=original_h,
-            target_width=original_w,
-        )
-        if with_countouers:
-            contour_all = []
-            temp = np.zeros((original_h, original_w, 1))
-            for i, mask in enumerate(annotations):
-                if isinstance(mask, dict):
-                    mask = mask['segmentation']
-                annotation = mask.astype(np.uint8)
-                if not retina:
-                    annotation = cv2.resize(
-                        annotation,
-                        (original_w, original_h),
-                        interpolation=cv2.INTER_NEAREST,
-                    )
-                contours, hierarchy = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
-                contour_all.extend(iter(contours))
-            cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
-            color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
-            contour_mask = temp / 255 * color.reshape(1, 1, -1)
-            plt.imshow(contour_mask)
-        save_path = Path(output) / result_name
-        save_path.parent.mkdir(exist_ok=True, parents=True)
-        plt.axis('off')
-        fig.savefig(save_path)
-        LOGGER.info(f'Saved to {save_path.absolute()}')
+             withContours=True):
+        n = len(annotations)
+        pbar = tqdm(annotations, total=n, bar_format=TQDM_BAR_FORMAT)
+        for ann in pbar:
+            result_name = os.path.basename(ann.path)
+            image = ann.orig_img
+            original_h, original_w = ann.orig_shape
+            # for macOS only
+            # plt.switch_backend('TkAgg')
+            plt.figure(figsize=(original_w / 100, original_h / 100))
+            # Add subplot with no margin.
+            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+            plt.margins(0, 0)
+            plt.gca().xaxis.set_major_locator(plt.NullLocator())
+            plt.gca().yaxis.set_major_locator(plt.NullLocator())
+            plt.imshow(image)
+
+            if ann.masks is not None:
+                masks = ann.masks.data
+                if better_quality:
+                    if isinstance(masks[0], torch.Tensor):
+                        masks = np.array(masks.cpu())
+                    for i, mask in enumerate(masks):
+                        mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
+                        masks[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
+
+                self.fast_show_mask(
+                    masks,
+                    plt.gca(),
+                    random_color=mask_random_color,
+                    bbox=bbox,
+                    points=points,
+                    pointlabel=point_label,
+                    retinamask=retina,
+                    target_height=original_h,
+                    target_width=original_w,
+                )
+
+                if withContours:
+                    contour_all = []
+                    temp = np.zeros((original_h, original_w, 1))
+                    for i, mask in enumerate(masks):
+                        mask = mask.astype(np.uint8)
+                        if not retina:
+                            mask = cv2.resize(
+                                mask,
+                                (original_w, original_h),
+                                interpolation=cv2.INTER_NEAREST,
+                            )
+                        contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+                        contour_all.extend(iter(contours))
+                    cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
+                    color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
+                    contour_mask = temp / 255 * color.reshape(1, 1, -1)
+                    plt.imshow(contour_mask)
+
+            plt.axis('off')
+            fig = plt.gcf()
+            try:
+                buf = fig.canvas.tostring_rgb()
+            except AttributeError:
+                fig.canvas.draw()
+                buf = fig.canvas.tostring_rgb()
+            cols, rows = fig.canvas.get_width_height()
+            img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
+
+            save_path = Path(output) / result_name
+            save_path.parent.mkdir(exist_ok=True, parents=True)
+            cv2.imwrite(str(save_path), img_array)
+            plt.close()
+            pbar.set_description('Saving {} to {}'.format(result_name, save_path))

     # CPU post process
     @staticmethod
     def fast_show_mask(
         annotation,
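
For reference, a minimal standalone sketch of the render-to-array saving path the new plot() loop relies on (backend choice and file name are assumptions; the sketch converts RGB to BGR before cv2.imwrite, whereas the commit writes the raw buffer directly):

    import cv2
    import matplotlib
    matplotlib.use('Agg')  # assumed non-interactive backend
    import matplotlib.pyplot as plt
    import numpy as np

    fig = plt.figure(figsize=(2, 2))
    plt.imshow(np.zeros((100, 100, 3), dtype=np.uint8))
    plt.axis('off')

    fig.canvas.draw()                    # make sure the canvas has rendered
    buf = fig.canvas.tostring_rgb()      # raw RGB bytes from the canvas
    cols, rows = fig.canvas.get_width_height()
    img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)

    cv2.imwrite('render_sketch.png', cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
    plt.close(fig)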
@@ -215,8 +224,9 @@ class FastSAMPrompt:
         return probs[:, 0].softmax(dim=0)

     def _crop_image(self, format_results):
-        image = Image.fromarray(cv2.cvtColor(self.ori_img, cv2.COLOR_BGR2RGB))
+        if os.path.isdir(self.source):
+            raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
+        image = Image.fromarray(cv2.cvtColor(self.results[0].orig_img, cv2.COLOR_BGR2RGB))
         ori_w, ori_h = image.size
         annotations = format_results
         mask_h, mask_w = annotations[0]['segmentation'].shape
@@ -237,65 +247,71 @@ class FastSAMPrompt:
         return cropped_boxes, cropped_images, not_crop, filter_id, annotations

     def box_prompt(self, bbox):
-        assert (bbox[2] != 0 and bbox[3] != 0)
-        masks = self.results[0].masks.data
-        target_height = self.ori_img.shape[0]
-        target_width = self.ori_img.shape[1]
-        h = masks.shape[1]
-        w = masks.shape[2]
-        if h != target_height or w != target_width:
-            bbox = [
-                int(bbox[0] * w / target_width),
-                int(bbox[1] * h / target_height),
-                int(bbox[2] * w / target_width),
-                int(bbox[3] * h / target_height), ]
-        bbox[0] = max(round(bbox[0]), 0)
-        bbox[1] = max(round(bbox[1]), 0)
-        bbox[2] = min(round(bbox[2]), w)
-        bbox[3] = min(round(bbox[3]), h)
-
-        # IoUs = torch.zeros(len(masks), dtype=torch.float32)
-        bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
-
-        masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2))
-        orig_masks_area = torch.sum(masks, dim=(1, 2))
-
-        union = bbox_area + orig_masks_area - masks_area
-        IoUs = masks_area / union
-        max_iou_index = torch.argmax(IoUs)
-
-        return np.array([masks[max_iou_index].cpu().numpy()])
+        if self.results[0].masks is not None:
+            assert (bbox[2] != 0 and bbox[3] != 0)
+            if os.path.isdir(self.source):
+                raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
+            masks = self.results[0].masks.data
+            target_height, target_width = self.results[0].orig_shape
+            h = masks.shape[1]
+            w = masks.shape[2]
+            if h != target_height or w != target_width:
+                bbox = [
+                    int(bbox[0] * w / target_width),
+                    int(bbox[1] * h / target_height),
+                    int(bbox[2] * w / target_width),
+                    int(bbox[3] * h / target_height), ]
+            bbox[0] = max(round(bbox[0]), 0)
+            bbox[1] = max(round(bbox[1]), 0)
+            bbox[2] = min(round(bbox[2]), w)
+            bbox[3] = min(round(bbox[3]), h)
+
+            # IoUs = torch.zeros(len(masks), dtype=torch.float32)
+            bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
+
+            masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2))
+            orig_masks_area = torch.sum(masks, dim=(1, 2))
+
+            union = bbox_area + orig_masks_area - masks_area
+            IoUs = masks_area / union
+            max_iou_index = torch.argmax(IoUs)
+
+            self.results[0].masks.data = torch.tensor(np.array([masks[max_iou_index].cpu().numpy()]))
+        return self.results

     def point_prompt(self, points, pointlabel):  # numpy processing
-        masks = self._format_results(self.results[0], 0)
-        target_height = self.ori_img.shape[0]
-        target_width = self.ori_img.shape[1]
-        h = masks[0]['segmentation'].shape[0]
-        w = masks[0]['segmentation'].shape[1]
-        if h != target_height or w != target_width:
-            points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
-        onemask = np.zeros((h, w))
-        for i, annotation in enumerate(masks):
-            mask = annotation['segmentation'] if isinstance(annotation, dict) else annotation
-            for i, point in enumerate(points):
-                if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
-                    onemask += mask
-                if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
-                    onemask -= mask
-        onemask = onemask >= 1
-        return np.array([onemask])
+        if self.results[0].masks is not None:
+            if os.path.isdir(self.source):
+                raise ValueError(f"'{self.source}' is a directory, not a valid source for this function.")
+            masks = self._format_results(self.results[0], 0)
+            target_height, target_width = self.results[0].orig_shape
+            h = masks[0]['segmentation'].shape[0]
+            w = masks[0]['segmentation'].shape[1]
+            if h != target_height or w != target_width:
+                points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
+            onemask = np.zeros((h, w))
+            for i, annotation in enumerate(masks):
+                mask = annotation['segmentation'] if isinstance(annotation, dict) else annotation
+                for i, point in enumerate(points):
+                    if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
+                        onemask += mask
+                    if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
+                        onemask -= mask
+            onemask = onemask >= 1
+            self.results[0].masks.data = torch.tensor(np.array([onemask]))
+        return self.results

     def text_prompt(self, text):
-        format_results = self._format_results(self.results[0], 0)
-        cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
-        clip_model, preprocess = self.clip.load('ViT-B/32', device=self.device)
-        scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
-        max_idx = scores.argsort()
-        max_idx = max_idx[-1]
-        max_idx += sum(np.array(filter_id) <= int(max_idx))
-        return np.array([annotations[max_idx]['segmentation']])
+        if self.results[0].masks is not None:
+            format_results = self._format_results(self.results[0], 0)
+            cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
+            clip_model, preprocess = self.clip.load('ViT-B/32', device=self.device)
+            scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
+            max_idx = scores.argsort()
+            max_idx = max_idx[-1]
+            max_idx += sum(np.array(filter_id) <= int(max_idx))
+            self.results[0].masks.data = torch.tensor(np.array([ann['segmentation'] for ann in annotations]))
+        return self.results

     def everything_prompt(self):
-        return self.results[0].masks.data
+        return self.results
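
A hedged usage sketch of the directory workflow this commit enables (weight file name, import path, and parameter values are assumptions, not taken from the diff): run FastSAM on a folder, wrap the Results list in FastSAMPrompt, and pass the prompt output, which now returns the full Results list, straight to plot():

    from ultralytics import FastSAM
    from ultralytics.models.fastsam import FastSAMPrompt  # import path assumed

    source = 'path/to/images'                      # a directory now works as a source
    model = FastSAM('FastSAM-s.pt')                # assumed weight file
    everything_results = model(source, device='cpu', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)

    prompt_process = FastSAMPrompt(source, everything_results, device='cpu')

    ann = prompt_process.everything_prompt()
    # ann = prompt_process.box_prompt(bbox=[200, 200, 300, 300])
    # ann = prompt_process.point_prompt(points=[[200, 200]], pointlabel=[1])
    # ann = prompt_process.text_prompt(text='a photo of a dog')

    prompt_process.plot(annotations=ann, output='./output/')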