Mirror of https://github.com/THU-MIG/yolov10.git
Cleanup redundant SAM forward() methods (#4591)
			Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent 47ab96dab6
commit 2567b288c9
@@ -27,12 +27,8 @@ def test_checks():
 
 @pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
 def test_train():
-    YOLO(MODEL).train(data=DATA, imgsz=64, epochs=1, batch=-1, device=0)  # also test AutoBatch, requires imgsz>=64
-
-
-@pytest.mark.skipif(CUDA_DEVICE_COUNT < 2, reason=f'DDP is not available, {CUDA_DEVICE_COUNT} device(s) found')
-def test_train_ddp():
-    YOLO(MODEL).train(data=DATA, imgsz=64, epochs=1, device=[0, 1])  # requires imgsz>=64
+    device = 0 if CUDA_DEVICE_COUNT < 2 else [0, 1]
+    YOLO(MODEL).train(data=DATA, imgsz=64, epochs=1, batch=-1, device=device)  # also test AutoBatch, requires imgsz>=64
 
 
 @pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason='CUDA is not available')
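
The consolidated test keys the device argument off the number of visible GPUs, so one test now covers both the single-GPU and the DDP path. A minimal standalone sketch of the same pattern, using torch.cuda.device_count() as a stand-in for the test file's CUDA_DEVICE_COUNT constant (an assumption, not the repo helper):

import torch
from ultralytics import YOLO

n_gpus = torch.cuda.device_count()                  # stand-in for CUDA_DEVICE_COUNT
device = 0 if n_gpus < 2 else list(range(n_gpus))   # int for one GPU, a list of indices triggers DDP

# batch=-1 also exercises AutoBatch, as in the test; model and dataset names are illustrative
YOLO('yolov8n.pt').train(data='coco8.yaml', imgsz=64, epochs=1, batch=-1, device=device)
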
@@ -119,7 +119,7 @@ class LoadStreams:
         # Wait until a frame is available in each buffer
         while not all(self.imgs):
             if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'):  # q to quit
-                cv2.destroyAllWindows()
+                self.close()
                 raise StopIteration
             time.sleep(1 / min(self.fps))
 
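
The loader now delegates cleanup to a close() method instead of only destroying OpenCV windows inline. The diff does not show close() itself; a plausible sketch of such a method, assuming the loader keeps its reader threads in self.threads and its cv2.VideoCapture handles in self.caps:

import cv2

def close(self):
    """Stop reader threads, release captures, and close display windows (hypothetical sketch)."""
    self.running = False              # assumed flag polled by the reader threads
    for t in self.threads:
        if t.is_alive():
            t.join(timeout=5)         # let readers exit cleanly
    for cap in self.caps:             # assumed list of cv2.VideoCapture objects
        cap.release()
    cv2.destroyAllWindows()           # same window cleanup the old inline call performed
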
@@ -6,11 +6,10 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Dict, List, Tuple
+from typing import List
 
 import torch
 from torch import nn
-from torch.nn import functional as F
 
 from .decoders import MaskDecoder
 from .encoders import ImageEncoderViT, PromptEncoder
@@ -31,6 +30,9 @@ class Sam(nn.Module):
         """
         SAM predicts object masks from an image and input prompts.
 
+        Note:
+            All forward() operations moved to SAMPredictor.
+
         Args:
           image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings that allow for
             efficient mask prediction.
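
With forward() removed from the module, inference goes through the predictor layer instead. A usage sketch assuming the high-level interface Ultralytics documents for SAM (the exact predictor entry point at this commit may differ; weights and image paths are illustrative):

from ultralytics import SAM

model = SAM('sam_b.pt')
results = model('path/to/image.jpg', bboxes=[100, 100, 400, 400])   # box prompt
results = model('path/to/image.jpg', points=[300, 200], labels=[1])  # point prompt
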
@@ -45,109 +47,3 @@ class Sam(nn.Module):
         self.mask_decoder = mask_decoder
         self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(-1, 1, 1), False)
         self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(-1, 1, 1), False)
-
-    @property
-    def device(self) -> Any:
-        return self.pixel_mean.device
-
-    @torch.no_grad()
-    def forward(
-        self,
-        batched_input: List[Dict[str, Any]],
-        multimask_output: bool,
-    ) -> List[Dict[str, torch.Tensor]]:
-        """
-        Predicts masks end-to-end from provided images and prompts. If prompts are not known in advance, using
-        SamPredictor is recommended over calling the model directly.
-
-        Args:
-          batched_input (list(dict)): A list over input images, each a dictionary with the following keys. A prompt
-              key can be excluded if it is not present.
-              'image': The image as a torch tensor in 3xHxW format, already transformed for input to the model.
-              'original_size': (tuple(int, int)) The original size of the image before transformation, as (H, W).
-              'point_coords': (torch.Tensor) Batched point prompts for this image, with shape BxNx2. Already
-                transformed to the input frame of the model.
-              'point_labels': (torch.Tensor) Batched labels for point prompts, with shape BxN.
-              'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. Already transformed to the input frame of
-                the model.
-              'mask_inputs': (torch.Tensor) Batched mask inputs to the model, in the form Bx1xHxW.
-          multimask_output (bool): Whether the model should predict multiple disambiguating masks, or return a single
-            mask.
-
-        Returns:
-          (list(dict)): A list over input images, where each element is as dictionary with the following keys.
-              'masks': (torch.Tensor) Batched binary mask predictions, with shape BxCxHxW, where B is the number of
-                input prompts, C is determined by multimask_output, and (H, W) is the original size of the image.
-              'iou_predictions': (torch.Tensor) The model's predictions of mask quality, in shape BxC.
-              'low_res_logits': (torch.Tensor) Low resolution logits with shape BxCxHxW, where H=W=256. Can be passed
-                as mask input to subsequent iterations of prediction.
-        """
-        input_images = torch.stack([self.preprocess(x['image']) for x in batched_input], dim=0)
-        image_embeddings = self.image_encoder(input_images)
-
-        outputs = []
-        for image_record, curr_embedding in zip(batched_input, image_embeddings):
-            if 'point_coords' in image_record:
-                points = (image_record['point_coords'], image_record['point_labels'])
-            else:
-                points = None
-            sparse_embeddings, dense_embeddings = self.prompt_encoder(
-                points=points,
-                boxes=image_record.get('boxes', None),
-                masks=image_record.get('mask_inputs', None),
-            )
-            low_res_masks, iou_predictions = self.mask_decoder(
-                image_embeddings=curr_embedding.unsqueeze(0),
-                image_pe=self.prompt_encoder.get_dense_pe(),
-                sparse_prompt_embeddings=sparse_embeddings,
-                dense_prompt_embeddings=dense_embeddings,
-                multimask_output=multimask_output,
-            )
-            masks = self.postprocess_masks(
-                low_res_masks,
-                input_size=image_record['image'].shape[-2:],
-                original_size=image_record['original_size'],
-            )
-            masks = masks > self.mask_threshold
-            outputs.append({
-                'masks': masks,
-                'iou_predictions': iou_predictions,
-                'low_res_logits': low_res_masks, })
-        return outputs
-
-    def postprocess_masks(
-        self,
-        masks: torch.Tensor,
-        input_size: Tuple[int, ...],
-        original_size: Tuple[int, ...],
-    ) -> torch.Tensor:
-        """
-        Remove padding and upscale masks to the original image size.
-
-        Args:
-          masks (torch.Tensor): Batched masks from the mask_decoder, in BxCxHxW format.
-          input_size (tuple(int, int)): The size of the model input image, in (H, W) format. Used to remove padding.
-          original_size (tuple(int, int)): The original image size before resizing for input to the model, in (H, W).
-
-        Returns:
-          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) is given by original_size.
-        """
-        masks = F.interpolate(
-            masks,
-            (self.image_encoder.img_size, self.image_encoder.img_size),
-            mode='bilinear',
-            align_corners=False,
-        )
-        masks = masks[..., :input_size[0], :input_size[1]]
-        return F.interpolate(masks, original_size, mode='bilinear', align_corners=False)
-
-    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
-        """Normalize pixel values and pad to a square input."""
-        # Normalize colors
-        x = (x - self.pixel_mean) / self.pixel_std
-
-        # Pad
-        h, w = x.shape[-2:]
-        padh = self.image_encoder.img_size - h
-        padw = self.image_encoder.img_size - w
-        return F.pad(x, (0, padw, 0, padh))
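
For readers tracing the deleted code: the padding removal and upscaling performed by the removed postprocess_masks() is a plain resize-crop-resize sequence that can be reproduced standalone. A minimal sketch mirroring the removed logic (not the SAMPredictor implementation itself):

import torch.nn.functional as F

def postprocess_masks_sketch(masks, img_size, input_size, original_size):
    """Upscale low-res masks to the padded model resolution, crop the padding,
    then resize to the original image size (mirrors the removed Sam.postprocess_masks)."""
    masks = F.interpolate(masks, (img_size, img_size), mode='bilinear', align_corners=False)
    masks = masks[..., :input_size[0], :input_size[1]]  # drop the square padding
    return F.interpolate(masks, original_size, mode='bilinear', align_corners=False)
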
@@ -519,9 +519,13 @@ def cuda_device_count() -> int:
         # Run the nvidia-smi command and capture its output
         output = subprocess.check_output(['nvidia-smi', '--query-gpu=count', '--format=csv,noheader,nounits'],
                                          encoding='utf-8')
-        return int(output.strip())
-    except (subprocess.CalledProcessError, FileNotFoundError):
-        # If the command fails or nvidia-smi is not found, assume no GPUs are available
+
+        # Take the first line and strip any leading/trailing white space
+        first_line = output.strip().split('\n')[0]
+
+        return int(first_line)
+    except (subprocess.CalledProcessError, FileNotFoundError, ValueError):
+        # If the command fails, nvidia-smi is not found, or output is not an integer, assume no GPUs are available
         return 0
 
 
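
With the --query-gpu=count query, nvidia-smi emits one count line per visible GPU, so int(output.strip()) fails on multi-GPU hosts; parsing only the first line and also catching ValueError lets the helper degrade to 0 instead of raising. A standalone sketch of the helper as it reads after this change:

import subprocess

def cuda_device_count() -> int:
    """Return the number of NVIDIA GPUs reported by nvidia-smi, or 0 on any failure."""
    try:
        output = subprocess.check_output(['nvidia-smi', '--query-gpu=count', '--format=csv,noheader,nounits'],
                                         encoding='utf-8')
        return int(output.strip().split('\n')[0])  # first line only; multi-GPU hosts print one line per device
    except (subprocess.CalledProcessError, FileNotFoundError, ValueError):
        return 0
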