Clean up unused imgsz (#7771)

This commit is contained in:
Laughing 2024-01-24 01:50:01 +08:00 committed by GitHub
parent f56dd0f48e
commit 67ae86f006
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 10 additions and 19 deletions

View File

@ -150,13 +150,12 @@ def check_source(source):
return source, webcam, screenshot, from_img, in_memory, tensor return source, webcam, screenshot, from_img, in_memory, tensor
def load_inference_source(source=None, imgsz=640, vid_stride=1, buffer=False): def load_inference_source(source=None, vid_stride=1, buffer=False):
""" """
Loads an inference source for object detection and applies necessary transformations. Loads an inference source for object detection and applies necessary transformations.
Args: Args:
source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference. source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference.
imgsz (int, optional): The size of the image for inference. Default is 640.
vid_stride (int, optional): The frame interval for video sources. Default is 1. vid_stride (int, optional): The frame interval for video sources. Default is 1.
buffer (bool, optional): Determines whether stream frames will be buffered. Default is False. buffer (bool, optional): Determines whether stream frames will be buffered. Default is False.
@ -172,13 +171,13 @@ def load_inference_source(source=None, imgsz=640, vid_stride=1, buffer=False):
elif in_memory: elif in_memory:
dataset = source dataset = source
elif webcam: elif webcam:
dataset = LoadStreams(source, imgsz=imgsz, vid_stride=vid_stride, buffer=buffer) dataset = LoadStreams(source, vid_stride=vid_stride, buffer=buffer)
elif screenshot: elif screenshot:
dataset = LoadScreenshots(source, imgsz=imgsz) dataset = LoadScreenshots(source)
elif from_img: elif from_img:
dataset = LoadPilAndNumpy(source, imgsz=imgsz) dataset = LoadPilAndNumpy(source)
else: else:
dataset = LoadImages(source, imgsz=imgsz, vid_stride=vid_stride) dataset = LoadImages(source, vid_stride=vid_stride)
# Attach source types to the dataset # Attach source types to the dataset
setattr(dataset, "source_type", source_type) setattr(dataset, "source_type", source_type)

View File

@ -38,7 +38,6 @@ class LoadStreams:
Attributes: Attributes:
sources (str): The source input paths or URLs for the video streams. sources (str): The source input paths or URLs for the video streams.
imgsz (int): The image size for processing, defaults to 640.
vid_stride (int): Video frame-rate stride, defaults to 1. vid_stride (int): Video frame-rate stride, defaults to 1.
buffer (bool): Whether to buffer input streams, defaults to False. buffer (bool): Whether to buffer input streams, defaults to False.
running (bool): Flag to indicate if the streaming thread is running. running (bool): Flag to indicate if the streaming thread is running.
@ -60,13 +59,12 @@ class LoadStreams:
__len__: Return the length of the sources object. __len__: Return the length of the sources object.
""" """
def __init__(self, sources="file.streams", imgsz=640, vid_stride=1, buffer=False): def __init__(self, sources="file.streams", vid_stride=1, buffer=False):
"""Initialize instance variables and check for consistent input stream shapes.""" """Initialize instance variables and check for consistent input stream shapes."""
torch.backends.cudnn.benchmark = True # faster for fixed-size inference torch.backends.cudnn.benchmark = True # faster for fixed-size inference
self.buffer = buffer # buffer input streams self.buffer = buffer # buffer input streams
self.running = True # running flag for Thread self.running = True # running flag for Thread
self.mode = "stream" self.mode = "stream"
self.imgsz = imgsz
self.vid_stride = vid_stride # video frame-rate stride self.vid_stride = vid_stride # video frame-rate stride
sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources] sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
@ -193,7 +191,6 @@ class LoadScreenshots:
Attributes: Attributes:
source (str): The source input indicating which screen to capture. source (str): The source input indicating which screen to capture.
imgsz (int): The image size for processing, defaults to 640.
screen (int): The screen number to capture. screen (int): The screen number to capture.
left (int): The left coordinate for screen capture area. left (int): The left coordinate for screen capture area.
top (int): The top coordinate for screen capture area. top (int): The top coordinate for screen capture area.
@ -210,7 +207,7 @@ class LoadScreenshots:
__next__: Captures the next screenshot and returns it. __next__: Captures the next screenshot and returns it.
""" """
def __init__(self, source, imgsz=640): def __init__(self, source):
"""Source = [screen_number left top width height] (pixels).""" """Source = [screen_number left top width height] (pixels)."""
check_requirements("mss") check_requirements("mss")
import mss # noqa import mss # noqa
@ -223,7 +220,6 @@ class LoadScreenshots:
left, top, width, height = (int(x) for x in params) left, top, width, height = (int(x) for x in params)
elif len(params) == 5: elif len(params) == 5:
self.screen, left, top, width, height = (int(x) for x in params) self.screen, left, top, width, height = (int(x) for x in params)
self.imgsz = imgsz
self.mode = "stream" self.mode = "stream"
self.frame = 0 self.frame = 0
self.sct = mss.mss() self.sct = mss.mss()
@ -258,7 +254,6 @@ class LoadImages:
various formats, including single image files, video files, and lists of image and video paths. various formats, including single image files, video files, and lists of image and video paths.
Attributes: Attributes:
imgsz (int): Image size, defaults to 640.
files (list): List of image and video file paths. files (list): List of image and video file paths.
nf (int): Total number of files (images and videos). nf (int): Total number of files (images and videos).
video_flag (list): Flags indicating whether a file is a video (True) or an image (False). video_flag (list): Flags indicating whether a file is a video (True) or an image (False).
@ -274,7 +269,7 @@ class LoadImages:
_new_video(path): Create a new cv2.VideoCapture object for a given video path. _new_video(path): Create a new cv2.VideoCapture object for a given video path.
""" """
def __init__(self, path, imgsz=640, vid_stride=1): def __init__(self, path, vid_stride=1):
"""Initialize the Dataloader and raise FileNotFoundError if file not found.""" """Initialize the Dataloader and raise FileNotFoundError if file not found."""
parent = None parent = None
if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line
@ -298,7 +293,6 @@ class LoadImages:
videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS] videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS]
ni, nv = len(images), len(videos) ni, nv = len(images), len(videos)
self.imgsz = imgsz
self.files = images + videos self.files = images + videos
self.nf = ni + nv # number of files self.nf = ni + nv # number of files
self.video_flag = [False] * ni + [True] * nv self.video_flag = [False] * ni + [True] * nv
@ -377,7 +371,6 @@ class LoadPilAndNumpy:
Attributes: Attributes:
paths (list): List of image paths or autogenerated filenames. paths (list): List of image paths or autogenerated filenames.
im0 (list): List of images stored as Numpy arrays. im0 (list): List of images stored as Numpy arrays.
imgsz (int): Image size, defaults to 640.
mode (str): Type of data being processed, defaults to 'image'. mode (str): Type of data being processed, defaults to 'image'.
bs (int): Batch size, equivalent to the length of `im0`. bs (int): Batch size, equivalent to the length of `im0`.
count (int): Counter for iteration, initialized at 0 during `__iter__()`. count (int): Counter for iteration, initialized at 0 during `__iter__()`.
@ -386,13 +379,12 @@ class LoadPilAndNumpy:
_single_check(im): Validate and format a single image to a Numpy array. _single_check(im): Validate and format a single image to a Numpy array.
""" """
def __init__(self, im0, imgsz=640): def __init__(self, im0):
"""Initialize PIL and Numpy Dataloader.""" """Initialize PIL and Numpy Dataloader."""
if not isinstance(im0, list): if not isinstance(im0, list):
im0 = [im0] im0 = [im0]
self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)] self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)]
self.im0 = [self._single_check(im) for im in im0] self.im0 = [self._single_check(im) for im in im0]
self.imgsz = imgsz
self.mode = "image" self.mode = "image"
# Generate fake paths # Generate fake paths
self.bs = len(self.im0) self.bs = len(self.im0)

View File

@ -226,7 +226,7 @@ class BasePredictor:
else None else None
) )
self.dataset = load_inference_source( self.dataset = load_inference_source(
source=source, imgsz=self.imgsz, vid_stride=self.args.vid_stride, buffer=self.args.stream_buffer source=source, vid_stride=self.args.vid_stride, buffer=self.args.stream_buffer
) )
self.source_type = self.dataset.source_type self.source_type = self.dataset.source_type
if not getattr(self, "stream", True) and ( if not getattr(self, "stream", True) and (