mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Buffered Mosaic for reduced HDD reads (#2791)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
dada5b73c4
commit
07b57c03c8
@ -93,7 +93,7 @@ class BaseMixTransform:
|
|||||||
indexes = [indexes]
|
indexes = [indexes]
|
||||||
|
|
||||||
# Get images information will be used for Mosaic or MixUp
|
# Get images information will be used for Mosaic or MixUp
|
||||||
mix_labels = [self.dataset.get_label_info(i) for i in indexes]
|
mix_labels = [self.dataset.get_image_and_label(i) for i in indexes]
|
||||||
|
|
||||||
if self.pre_transform is not None:
|
if self.pre_transform is not None:
|
||||||
for i, data in enumerate(mix_labels):
|
for i, data in enumerate(mix_labels):
|
||||||
@ -135,11 +135,14 @@ class Mosaic(BaseMixTransform):
|
|||||||
super().__init__(dataset=dataset, p=p)
|
super().__init__(dataset=dataset, p=p)
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.imgsz = imgsz
|
self.imgsz = imgsz
|
||||||
self.border = [-imgsz // 2, -imgsz // 2] if n == 4 else [-imgsz, -imgsz]
|
self.border = (-imgsz // 2, -imgsz // 2) # width, height
|
||||||
self.n = n
|
self.n = n
|
||||||
|
|
||||||
def get_indexes(self, buffer=True):
    """Return ``self.n - 1`` random dataset indexes.

    Args:
        buffer (bool): If True, draw the indexes (with replacement) from
            ``self.dataset.buffer``; if False, draw them uniformly from the
            whole dataset.

    Returns:
        (list[int]): ``self.n - 1`` indexes into ``self.dataset``.
    """
    k = self.n - 1
    if buffer:  # select images from buffer
        candidates = list(self.dataset.buffer)
        # random.choices() raises IndexError on an empty population, so the buffer is
        # only used when it holds at least one index; otherwise fall through and
        # sample from the whole dataset instead of crashing.
        if candidates:
            return random.choices(candidates, k=k)
    # select any images
    return [random.randint(0, len(self.dataset) - 1) for _ in range(k)]
|
||||||
|
|
||||||
def _mix_transform(self, labels):
|
def _mix_transform(self, labels):
|
||||||
@ -224,10 +227,12 @@ class Mosaic(BaseMixTransform):
|
|||||||
img9[y1:y2, x1:x2] = img[y1 - padh:, x1 - padw:] # img9[ymin:ymax, xmin:xmax]
|
img9[y1:y2, x1:x2] = img[y1 - padh:, x1 - padw:] # img9[ymin:ymax, xmin:xmax]
|
||||||
hp, wp = h, w # height, width previous for next iteration
|
hp, wp = h, w # height, width previous for next iteration
|
||||||
|
|
||||||
labels_patch = self._update_labels(labels_patch, padw, padh)
|
# Labels assuming imgsz*2 mosaic size
|
||||||
|
labels_patch = self._update_labels(labels_patch, padw + self.border[0], padh + self.border[1])
|
||||||
mosaic_labels.append(labels_patch)
|
mosaic_labels.append(labels_patch)
|
||||||
final_labels = self._cat_labels(mosaic_labels)
|
final_labels = self._cat_labels(mosaic_labels)
|
||||||
final_labels['img'] = img9
|
|
||||||
|
final_labels['img'] = img9[-self.border[0]:self.border[0], -self.border[1]:self.border[1]]
|
||||||
return final_labels
|
return final_labels
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -245,18 +250,20 @@ class Mosaic(BaseMixTransform):
|
|||||||
return {}
|
return {}
|
||||||
cls = []
|
cls = []
|
||||||
instances = []
|
instances = []
|
||||||
|
imgsz = self.imgsz * 2 # mosaic imgsz
|
||||||
for labels in mosaic_labels:
|
for labels in mosaic_labels:
|
||||||
cls.append(labels['cls'])
|
cls.append(labels['cls'])
|
||||||
instances.append(labels['instances'])
|
instances.append(labels['instances'])
|
||||||
final_labels = {
|
final_labels = {
|
||||||
'im_file': mosaic_labels[0]['im_file'],
|
'im_file': mosaic_labels[0]['im_file'],
|
||||||
'ori_shape': mosaic_labels[0]['ori_shape'],
|
'ori_shape': mosaic_labels[0]['ori_shape'],
|
||||||
'resized_shape': (self.imgsz * 2, self.imgsz * 2),
|
'resized_shape': (imgsz, imgsz),
|
||||||
'cls': np.concatenate(cls, 0),
|
'cls': np.concatenate(cls, 0),
|
||||||
'instances': Instances.concatenate(instances, axis=0),
|
'instances': Instances.concatenate(instances, axis=0),
|
||||||
'mosaic_border': self.border} # final_labels
|
'mosaic_border': self.border} # final_labels
|
||||||
clip_size = self.imgsz * (2 if self.n == 4 else 3)
|
final_labels['instances'].clip(imgsz, imgsz)
|
||||||
final_labels['instances'].clip(clip_size, clip_size)
|
good = final_labels['instances'].remove_zero_area_boxes()
|
||||||
|
final_labels['cls'] = final_labels['cls'][good]
|
||||||
return final_labels
|
return final_labels
|
||||||
|
|
||||||
|
|
||||||
|
@ -80,7 +80,7 @@ class BaseDataset(Dataset):
|
|||||||
# Cache stuff
|
# Cache stuff
|
||||||
if cache == 'ram' and not self.check_cache_ram():
|
if cache == 'ram' and not self.check_cache_ram():
|
||||||
cache = False
|
cache = False
|
||||||
self.ims = [None] * self.ni
|
self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
|
||||||
self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
|
self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
|
||||||
if cache:
|
if cache:
|
||||||
self.cache_images(cache)
|
self.cache_images(cache)
|
||||||
@ -88,6 +88,10 @@ class BaseDataset(Dataset):
|
|||||||
# Transforms
|
# Transforms
|
||||||
self.transforms = self.build_transforms(hyp=hyp)
|
self.transforms = self.build_transforms(hyp=hyp)
|
||||||
|
|
||||||
|
# Buffer of recently loaded image indexes for Mosaic to sample from
|
||||||
|
self.buffer = [] # buffer size = batch size
|
||||||
|
self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0
|
||||||
|
|
||||||
def get_img_files(self, img_path):
|
def get_img_files(self, img_path):
|
||||||
"""Read image files."""
|
"""Read image files."""
|
||||||
try:
|
try:
|
||||||
@ -147,13 +151,22 @@ class BaseDataset(Dataset):
|
|||||||
interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
|
interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
|
||||||
im = cv2.resize(im, (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz)),
|
im = cv2.resize(im, (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz)),
|
||||||
interpolation=interp)
|
interpolation=interp)
|
||||||
return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
|
|
||||||
return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized
|
# Add to buffer if training with augmentations
|
||||||
|
if self.augment:
|
||||||
|
self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
|
||||||
|
self.buffer.append(i)
|
||||||
|
if len(self.buffer) >= self.max_buffer_length:
|
||||||
|
j = self.buffer.pop(0)
|
||||||
|
self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None
|
||||||
|
|
||||||
|
return im, (h0, w0), im.shape[:2]
|
||||||
|
|
||||||
|
return self.ims[i], self.im_hw0[i], self.im_hw[i]
|
||||||
|
|
||||||
def cache_images(self, cache):
|
def cache_images(self, cache):
|
||||||
"""Cache images to memory or disk."""
|
"""Cache images to memory or disk."""
|
||||||
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
|
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
|
||||||
self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
|
|
||||||
fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image
|
fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image
|
||||||
with ThreadPool(NUM_THREADS) as pool:
|
with ThreadPool(NUM_THREADS) as pool:
|
||||||
results = pool.imap(fcn, range(self.ni))
|
results = pool.imap(fcn, range(self.ni))
|
||||||
@ -218,9 +231,9 @@ class BaseDataset(Dataset):
|
|||||||
|
|
||||||
def __getitem__(self, index):
    """Return the result of applying ``self.transforms`` to ``self.get_image_and_label(index)``."""
    sample = self.get_image_and_label(index)
    return self.transforms(sample)
|
||||||
|
|
||||||
def get_label_info(self, index):
|
def get_image_and_label(self, index):
|
||||||
"""Get and return label information from the dataset."""
|
"""Get and return label information from the dataset."""
|
||||||
label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
|
label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
|
||||||
label.pop('shape', None) # shape is for rect, remove it
|
label.pop('shape', None) # shape is for rect, remove it
|
||||||
@ -229,8 +242,7 @@ class BaseDataset(Dataset):
|
|||||||
label['resized_shape'][1] / label['ori_shape'][1]) # for evaluation
|
label['resized_shape'][1] / label['ori_shape'][1]) # for evaluation
|
||||||
if self.rect:
|
if self.rect:
|
||||||
label['rect_shape'] = self.batch_shapes[self.batch[index]]
|
label['rect_shape'] = self.batch_shapes[self.batch[index]]
|
||||||
label = self.update_labels_info(label)
|
return self.update_labels_info(label)
|
||||||
return label
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
"""Returns the length of the labels list for the dataset."""
|
"""Returns the length of the labels list for the dataset."""
|
||||||
|
@ -326,10 +326,20 @@ class Instances:
|
|||||||
self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w)
|
self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w)
|
||||||
self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)
|
self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)
|
||||||
|
|
||||||
|
def remove_zero_area_boxes(self):
    """Remove boxes whose area is not positive, e.g. boxes left with zero width or height after clipping.

    When at least one box is dropped, ``self._bboxes`` is rebuilt from the kept rows
    (keeping the current box format), and ``self.segments`` (if non-empty) and
    ``self.keypoints`` (if not None) are filtered with the same mask.

    Returns:
        The boolean mask ``self._bboxes.areas() > 0``; True marks a kept box.
    """
    good = self._bboxes.areas() > 0
    # Use the mask's own reduction rather than the builtin all(), which would walk the
    # array element by element in Python. NOTE(review): assumes areas() returns an
    # array type exposing .all() (e.g. numpy) — confirm against Bboxes.areas().
    if not good.all():
        self._bboxes = Bboxes(self._bboxes.bboxes[good], format=self._bboxes.format)
        if len(self.segments):
            self.segments = self.segments[good]
        if self.keypoints is not None:
            self.keypoints = self.keypoints[good]
    return good
|
||||||
|
|
||||||
def update(self, bboxes, segments=None, keypoints=None):
|
def update(self, bboxes, segments=None, keypoints=None):
|
||||||
"""Updates instance variables."""
|
"""Updates instance variables."""
|
||||||
new_bboxes = Bboxes(bboxes, format=self._bboxes.format)
|
self._bboxes = Bboxes(bboxes, format=self._bboxes.format)
|
||||||
self._bboxes = new_bboxes
|
|
||||||
if segments is not None:
|
if segments is not None:
|
||||||
self.segments = segments
|
self.segments = segments
|
||||||
if keypoints is not None:
|
if keypoints is not None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user