mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Cleanup redundant SAM forward()
methods (#4591)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
47ab96dab6
commit
2567b288c9
@ -27,12 +27,8 @@ def test_checks():
|
|||||||
|
|
||||||
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
|
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
|
||||||
def test_train():
|
def test_train():
|
||||||
YOLO(MODEL).train(data=DATA, imgsz=64, epochs=1, batch=-1, device=0) # also test AutoBatch, requires imgsz>=64
|
device = 0 if CUDA_DEVICE_COUNT < 2 else [0, 1]
|
||||||
|
YOLO(MODEL).train(data=DATA, imgsz=64, epochs=1, batch=-1, device=device) # also test AutoBatch, requires imgsz>=64
|
||||||
|
|
||||||
@pytest.mark.skipif(CUDA_DEVICE_COUNT < 2, reason=f'DDP is not available, {CUDA_DEVICE_COUNT} device(s) found')
|
|
||||||
def test_train_ddp():
|
|
||||||
YOLO(MODEL).train(data=DATA, imgsz=64, epochs=1, device=[0, 1]) # requires imgsz>=64
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
|
@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
|
||||||
|
@ -119,7 +119,7 @@ class LoadStreams:
|
|||||||
# Wait until a frame is available in each buffer
|
# Wait until a frame is available in each buffer
|
||||||
while not all(self.imgs):
|
while not all(self.imgs):
|
||||||
if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
|
if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
|
||||||
cv2.destroyAllWindows()
|
self.close()
|
||||||
raise StopIteration
|
raise StopIteration
|
||||||
time.sleep(1 / min(self.fps))
|
time.sleep(1 / min(self.fps))
|
||||||
|
|
||||||
|
@ -6,11 +6,10 @@
|
|||||||
# This source code is licensed under the license found in the
|
# This source code is licensed under the license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
from .decoders import MaskDecoder
|
from .decoders import MaskDecoder
|
||||||
from .encoders import ImageEncoderViT, PromptEncoder
|
from .encoders import ImageEncoderViT, PromptEncoder
|
||||||
@ -31,6 +30,9 @@ class Sam(nn.Module):
|
|||||||
"""
|
"""
|
||||||
SAM predicts object masks from an image and input prompts.
|
SAM predicts object masks from an image and input prompts.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
All forward() operations moved to SAMPredictor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings that allow for
|
image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings that allow for
|
||||||
efficient mask prediction.
|
efficient mask prediction.
|
||||||
@ -45,109 +47,3 @@ class Sam(nn.Module):
|
|||||||
self.mask_decoder = mask_decoder
|
self.mask_decoder = mask_decoder
|
||||||
self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(-1, 1, 1), False)
|
self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(-1, 1, 1), False)
|
||||||
self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(-1, 1, 1), False)
|
self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(-1, 1, 1), False)
|
||||||
|
|
||||||
@property
|
|
||||||
def device(self) -> Any:
|
|
||||||
return self.pixel_mean.device
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
batched_input: List[Dict[str, Any]],
|
|
||||||
multimask_output: bool,
|
|
||||||
) -> List[Dict[str, torch.Tensor]]:
|
|
||||||
"""
|
|
||||||
Predicts masks end-to-end from provided images and prompts. If prompts are not known in advance, using
|
|
||||||
SamPredictor is recommended over calling the model directly.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
batched_input (list(dict)): A list over input images, each a dictionary with the following keys. A prompt
|
|
||||||
key can be excluded if it is not present.
|
|
||||||
'image': The image as a torch tensor in 3xHxW format, already transformed for input to the model.
|
|
||||||
'original_size': (tuple(int, int)) The original size of the image before transformation, as (H, W).
|
|
||||||
'point_coords': (torch.Tensor) Batched point prompts for this image, with shape BxNx2. Already
|
|
||||||
transformed to the input frame of the model.
|
|
||||||
'point_labels': (torch.Tensor) Batched labels for point prompts, with shape BxN.
|
|
||||||
'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. Already transformed to the input frame of
|
|
||||||
the model.
|
|
||||||
'mask_inputs': (torch.Tensor) Batched mask inputs to the model, in the form Bx1xHxW.
|
|
||||||
multimask_output (bool): Whether the model should predict multiple disambiguating masks, or return a single
|
|
||||||
mask.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(list(dict)): A list over input images, where each element is as dictionary with the following keys.
|
|
||||||
'masks': (torch.Tensor) Batched binary mask predictions, with shape BxCxHxW, where B is the number of
|
|
||||||
input prompts, C is determined by multimask_output, and (H, W) is the original size of the image.
|
|
||||||
'iou_predictions': (torch.Tensor) The model's predictions of mask quality, in shape BxC.
|
|
||||||
'low_res_logits': (torch.Tensor) Low resolution logits with shape BxCxHxW, where H=W=256. Can be passed
|
|
||||||
as mask input to subsequent iterations of prediction.
|
|
||||||
"""
|
|
||||||
input_images = torch.stack([self.preprocess(x['image']) for x in batched_input], dim=0)
|
|
||||||
image_embeddings = self.image_encoder(input_images)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
for image_record, curr_embedding in zip(batched_input, image_embeddings):
|
|
||||||
if 'point_coords' in image_record:
|
|
||||||
points = (image_record['point_coords'], image_record['point_labels'])
|
|
||||||
else:
|
|
||||||
points = None
|
|
||||||
sparse_embeddings, dense_embeddings = self.prompt_encoder(
|
|
||||||
points=points,
|
|
||||||
boxes=image_record.get('boxes', None),
|
|
||||||
masks=image_record.get('mask_inputs', None),
|
|
||||||
)
|
|
||||||
low_res_masks, iou_predictions = self.mask_decoder(
|
|
||||||
image_embeddings=curr_embedding.unsqueeze(0),
|
|
||||||
image_pe=self.prompt_encoder.get_dense_pe(),
|
|
||||||
sparse_prompt_embeddings=sparse_embeddings,
|
|
||||||
dense_prompt_embeddings=dense_embeddings,
|
|
||||||
multimask_output=multimask_output,
|
|
||||||
)
|
|
||||||
masks = self.postprocess_masks(
|
|
||||||
low_res_masks,
|
|
||||||
input_size=image_record['image'].shape[-2:],
|
|
||||||
original_size=image_record['original_size'],
|
|
||||||
)
|
|
||||||
masks = masks > self.mask_threshold
|
|
||||||
outputs.append({
|
|
||||||
'masks': masks,
|
|
||||||
'iou_predictions': iou_predictions,
|
|
||||||
'low_res_logits': low_res_masks, })
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
def postprocess_masks(
|
|
||||||
self,
|
|
||||||
masks: torch.Tensor,
|
|
||||||
input_size: Tuple[int, ...],
|
|
||||||
original_size: Tuple[int, ...],
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Remove padding and upscale masks to the original image size.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
masks (torch.Tensor): Batched masks from the mask_decoder, in BxCxHxW format.
|
|
||||||
input_size (tuple(int, int)): The size of the model input image, in (H, W) format. Used to remove padding.
|
|
||||||
original_size (tuple(int, int)): The original image size before resizing for input to the model, in (H, W).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(torch.Tensor): Batched masks in BxCxHxW format, where (H, W) is given by original_size.
|
|
||||||
"""
|
|
||||||
masks = F.interpolate(
|
|
||||||
masks,
|
|
||||||
(self.image_encoder.img_size, self.image_encoder.img_size),
|
|
||||||
mode='bilinear',
|
|
||||||
align_corners=False,
|
|
||||||
)
|
|
||||||
masks = masks[..., :input_size[0], :input_size[1]]
|
|
||||||
return F.interpolate(masks, original_size, mode='bilinear', align_corners=False)
|
|
||||||
|
|
||||||
def preprocess(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""Normalize pixel values and pad to a square input."""
|
|
||||||
# Normalize colors
|
|
||||||
x = (x - self.pixel_mean) / self.pixel_std
|
|
||||||
|
|
||||||
# Pad
|
|
||||||
h, w = x.shape[-2:]
|
|
||||||
padh = self.image_encoder.img_size - h
|
|
||||||
padw = self.image_encoder.img_size - w
|
|
||||||
return F.pad(x, (0, padw, 0, padh))
|
|
||||||
|
@ -519,9 +519,13 @@ def cuda_device_count() -> int:
|
|||||||
# Run the nvidia-smi command and capture its output
|
# Run the nvidia-smi command and capture its output
|
||||||
output = subprocess.check_output(['nvidia-smi', '--query-gpu=count', '--format=csv,noheader,nounits'],
|
output = subprocess.check_output(['nvidia-smi', '--query-gpu=count', '--format=csv,noheader,nounits'],
|
||||||
encoding='utf-8')
|
encoding='utf-8')
|
||||||
return int(output.strip())
|
|
||||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
# Take the first line and strip any leading/trailing white space
|
||||||
# If the command fails or nvidia-smi is not found, assume no GPUs are available
|
first_line = output.strip().split('\n')[0]
|
||||||
|
|
||||||
|
return int(first_line)
|
||||||
|
except (subprocess.CalledProcessError, FileNotFoundError, ValueError):
|
||||||
|
# If the command fails, nvidia-smi is not found, or output is not an integer, assume no GPUs are available
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user