mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
Add SAM Predictor remove_small_regions
test (#4576)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
b4dca690d4
commit
e9f596430f
@ -75,6 +75,7 @@ def test_fastsam(task='segment', model=WEIGHTS_DIR / 'FastSAM-s.pt', data='coco8
|
|||||||
|
|
||||||
from ultralytics import FastSAM
|
from ultralytics import FastSAM
|
||||||
from ultralytics.models.fastsam import FastSAMPrompt
|
from ultralytics.models.fastsam import FastSAMPrompt
|
||||||
|
from ultralytics.models.sam import Predictor
|
||||||
|
|
||||||
# Create a FastSAM model
|
# Create a FastSAM model
|
||||||
sam_model = FastSAM(model) # or FastSAM-x.pt
|
sam_model = FastSAM(model) # or FastSAM-x.pt
|
||||||
@ -82,6 +83,9 @@ def test_fastsam(task='segment', model=WEIGHTS_DIR / 'FastSAM-s.pt', data='coco8
|
|||||||
# Run inference on an image
|
# Run inference on an image
|
||||||
everything_results = sam_model(source, device='cpu', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
|
everything_results = sam_model(source, device='cpu', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
|
||||||
|
|
||||||
|
# Remove small regions
|
||||||
|
new_masks, _ = Predictor.remove_small_regions(everything_results[0].masks.data, min_area=20)
|
||||||
|
|
||||||
# Everything prompt
|
# Everything prompt
|
||||||
prompt_process = FastSAMPrompt(source, everything_results, device='cpu')
|
prompt_process = FastSAMPrompt(source, everything_results, device='cpu')
|
||||||
ann = prompt_process.everything_prompt()
|
ann = prompt_process.everything_prompt()
|
||||||
|
@ -374,6 +374,10 @@ class Predictor(BasePredictor):
|
|||||||
masks (torch.Tensor): Masks, (N, H, W).
|
masks (torch.Tensor): Masks, (N, H, W).
|
||||||
min_area (int): Minimum area threshold.
|
min_area (int): Minimum area threshold.
|
||||||
nms_thresh (float): NMS threshold.
|
nms_thresh (float): NMS threshold.
|
||||||
|
Returns:
|
||||||
|
new_masks (torch.Tensor): New Masks, (N, H, W).
|
||||||
|
keep (List[int]): The indices of the new masks, which can be used to filter
|
||||||
|
the corresponding boxes.
|
||||||
"""
|
"""
|
||||||
if len(masks) == 0:
|
if len(masks) == 0:
|
||||||
return masks
|
return masks
|
||||||
@ -382,7 +386,7 @@ class Predictor(BasePredictor):
|
|||||||
new_masks = []
|
new_masks = []
|
||||||
scores = []
|
scores = []
|
||||||
for mask in masks:
|
for mask in masks:
|
||||||
mask = mask.cpu().numpy()
|
mask = mask.cpu().numpy().astype(np.uint8)
|
||||||
mask, changed = remove_small_regions(mask, min_area, mode='holes')
|
mask, changed = remove_small_regions(mask, min_area, mode='holes')
|
||||||
unchanged = not changed
|
unchanged = not changed
|
||||||
mask, changed = remove_small_regions(mask, min_area, mode='islands')
|
mask, changed = remove_small_regions(mask, min_area, mode='islands')
|
||||||
@ -402,9 +406,4 @@ class Predictor(BasePredictor):
|
|||||||
nms_thresh,
|
nms_thresh,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Only recalculate masks for masks that have changed
|
return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep
|
||||||
for i in keep:
|
|
||||||
if scores[i] == 0.0:
|
|
||||||
masks[i] = new_masks[i]
|
|
||||||
|
|
||||||
return masks[keep]
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user