mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
Fix dataloader2 (#35)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
523eff99e2
commit
c617ee1c79
@ -26,4 +26,4 @@ flipud: 0.0 # image flip up-down (probability)
|
||||
fliplr: 0.5 # image flip left-right (probability)
|
||||
mosaic: 1.0 # image mosaic (probability)
|
||||
mixup: 0.0 # image mixup (probability)
|
||||
copy_paste: 0.0 # segment copy-paste (probability)
|
||||
copy_paste: 0.5 # segment copy-paste (probability)
|
||||
|
@ -67,15 +67,17 @@ def plot_keypoint(img, keypoints, color, tl):
|
||||
with open("ultralytics/tests/data/dataloader/hyp_test.yaml") as f:
|
||||
hyp = OmegaConf.load(f)
|
||||
|
||||
dataloader, dataset = build_dataloader(
|
||||
|
||||
def test(augment, rect):
|
||||
dataloader, _ = build_dataloader(
|
||||
img_path="/d/dataset/COCO/images/val2017",
|
||||
img_size=640,
|
||||
label_path=None,
|
||||
cache=False,
|
||||
hyp=hyp,
|
||||
augment=False,
|
||||
augment=augment,
|
||||
prefix="",
|
||||
rect=False,
|
||||
rect=rect,
|
||||
batch_size=4,
|
||||
stride=32,
|
||||
pad=0.5,
|
||||
@ -108,7 +110,17 @@ for d in dataloader:
|
||||
y2 = y + h / 2
|
||||
c = int(cls[i][0])
|
||||
# print(x1, y1, x2, y2)
|
||||
plot_one_box([int(x1), int(y1), int(x2), int(y2)], img, keypoints=keypoints[i], label=f"{c}", color=colors(c))
|
||||
plot_one_box([int(x1), int(y1), int(x2), int(y2)],
|
||||
img,
|
||||
keypoints=keypoints[i],
|
||||
label=f"{c}",
|
||||
color=colors(c))
|
||||
cv2.imshow("p", img)
|
||||
if cv2.waitKey(0) == ord("q"):
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test(augment=True, rect=False)
|
||||
test(augment=False, rect=True)
|
||||
test(augment=False, rect=False)
|
||||
|
@ -55,15 +55,17 @@ def plot_one_box(x, img, color=None, label=None, line_thickness=None):
|
||||
with open("ultralytics/tests/data/dataloader/hyp_test.yaml") as f:
|
||||
hyp = OmegaConf.load(f)
|
||||
|
||||
dataloader, dataset = build_dataloader(
|
||||
|
||||
def test(augment, rect):
|
||||
dataloader, _ = build_dataloader(
|
||||
img_path="/d/dataset/COCO/coco128-seg/images",
|
||||
img_size=640,
|
||||
label_path=None,
|
||||
cache=False,
|
||||
hyp=hyp,
|
||||
augment=False,
|
||||
augment=augment,
|
||||
prefix="",
|
||||
rect=False,
|
||||
rect=rect,
|
||||
batch_size=4,
|
||||
stride=32,
|
||||
pad=0.5,
|
||||
@ -72,6 +74,14 @@ dataloader, dataset = build_dataloader(
|
||||
)
|
||||
|
||||
for d in dataloader:
|
||||
# info
|
||||
im_file = d["im_file"]
|
||||
ori_shape = d["ori_shape"]
|
||||
resize_shape = d["resized_shape"]
|
||||
print(ori_shape, resize_shape)
|
||||
print(im_file)
|
||||
|
||||
# labels
|
||||
idx = 1 # show which image inside one batch
|
||||
img = d["img"][idx].numpy()
|
||||
img = np.ascontiguousarray(img.transpose(1, 2, 0))
|
||||
@ -110,3 +120,9 @@ for d in dataloader:
|
||||
cv2.imshow("p", img)
|
||||
if cv2.waitKey(0) == ord("q"):
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test(augment=True, rect=False)
|
||||
test(augment=False, rect=True)
|
||||
test(augment=False, rect=False)
|
||||
|
@ -184,7 +184,7 @@ class Mosaic(BaseMixTransform):
|
||||
cls.append(labels["cls"])
|
||||
instances.append(labels["instances"])
|
||||
final_labels = {
|
||||
"ori_shape": (self.img_size * 2, self.img_size * 2),
|
||||
"ori_shape": mosaic_labels[0]["ori_shape"],
|
||||
"resized_shape": (self.img_size * 2, self.img_size * 2),
|
||||
"im_file": mosaic_labels[0]["im_file"],
|
||||
"cls": np.concatenate(cls, 0)}
|
||||
@ -351,7 +351,7 @@ class RandomPerspective:
|
||||
"""
|
||||
img = labels["img"]
|
||||
cls = labels["cls"]
|
||||
instances = labels["instances"]
|
||||
instances = labels.pop("instances")
|
||||
# make sure the coord formats are right
|
||||
instances.convert_bbox(format="xyxy")
|
||||
instances.denormalize(*img.shape[:2][::-1])
|
||||
@ -372,6 +372,7 @@ class RandomPerspective:
|
||||
if keypoints is not None:
|
||||
keypoints = self.apply_keypoints(keypoints, M)
|
||||
new_instances = Instances(bboxes, segments, keypoints, bbox_format="xyxy", normalized=False)
|
||||
# clip
|
||||
new_instances.clip(*self.size)
|
||||
|
||||
# filter instances
|
||||
@ -381,9 +382,9 @@ class RandomPerspective:
|
||||
box2=new_instances.bboxes.T,
|
||||
area_thr=0.01 if len(segments) else 0.10)
|
||||
labels["instances"] = new_instances[i]
|
||||
# clip
|
||||
labels["cls"] = cls[i]
|
||||
labels["img"] = img
|
||||
labels["resized_shape"] = img.shape[:2]
|
||||
return labels
|
||||
|
||||
def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
|
||||
@ -430,7 +431,7 @@ class RandomFlip:
|
||||
|
||||
def __call__(self, labels):
|
||||
img = labels["img"]
|
||||
instances = labels["instances"]
|
||||
instances = labels.pop("instances")
|
||||
instances.convert_bbox(format="xywh")
|
||||
h, w = img.shape[:2]
|
||||
h = 1 if instances.normalized else h
|
||||
@ -439,13 +440,11 @@ class RandomFlip:
|
||||
# Flip up-down
|
||||
if self.direction == "vertical" and random.random() < self.p:
|
||||
img = np.flipud(img)
|
||||
img = np.ascontiguousarray(img)
|
||||
instances.flipud(h)
|
||||
if self.direction == "horizontal" and random.random() < self.p:
|
||||
img = np.fliplr(img)
|
||||
img = np.ascontiguousarray(img)
|
||||
instances.fliplr(w)
|
||||
labels["img"] = img
|
||||
labels["img"] = np.ascontiguousarray(img)
|
||||
labels["instances"] = instances
|
||||
return labels
|
||||
|
||||
@ -463,7 +462,7 @@ class LetterBox:
|
||||
def __call__(self, labels={}, image=None):
|
||||
img = image or labels["img"]
|
||||
shape = img.shape[:2] # current shape [height, width]
|
||||
new_shape = labels.get("rect_shape", self.new_shape)
|
||||
new_shape = labels.pop("rect_shape", self.new_shape)
|
||||
if isinstance(new_shape, int):
|
||||
new_shape = (new_shape, new_shape)
|
||||
|
||||
@ -495,6 +494,7 @@ class LetterBox:
|
||||
|
||||
labels = self._update_labels(labels, ratio, dw, dh)
|
||||
labels["img"] = img
|
||||
labels["resized_shape"] = new_shape
|
||||
return labels
|
||||
|
||||
def _update_labels(self, labels, ratio, padw, padh):
|
||||
@ -515,26 +515,21 @@ class CopyPaste:
|
||||
# Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
|
||||
im = labels["img"]
|
||||
cls = labels["cls"]
|
||||
bboxes = labels["instances"].bboxes
|
||||
segments = labels["instances"].segments # n, 1000, 2
|
||||
keypoints = labels["instances"].keypoints
|
||||
if self.p and len(segments):
|
||||
n = len(segments)
|
||||
instances = labels.pop("instances")
|
||||
instances.convert_bbox(format="xyxy")
|
||||
if self.p and len(instances.segments):
|
||||
n = len(instances)
|
||||
h, w, _ = im.shape # height, width, channels
|
||||
im_new = np.zeros(im.shape, np.uint8)
|
||||
# TODO: this implement can be parallel since segments are ndarray, also might work with Instances inside
|
||||
for j in random.sample(range(n), k=round(self.p * n)):
|
||||
c, b, s = cls[j], bboxes[j], segments[j]
|
||||
box = w - b[2], b[1], w - b[0], b[3]
|
||||
ioa = bbox_ioa(box, bboxes) # intersection over area
|
||||
if (ioa < 0.30).all(): # allow 30% obscuration of existing labels
|
||||
bboxes = np.concatenate((bboxes, [box]), 0)
|
||||
cls = np.concatenate((cls, c[None]), axis=0)
|
||||
segments = np.concatenate((segments, np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1)[None]), 0)
|
||||
if keypoints is not None:
|
||||
keypoints = np.concatenate(
|
||||
(keypoints, np.concatenate((w - keypoints[j][:, 0:1], keypoints[j][:, 1:2]), 1)), 0)
|
||||
cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED)
|
||||
j = random.sample(range(n), k=round(self.p * n))
|
||||
c, instance = cls[j], instances[j]
|
||||
instance.fliplr(w)
|
||||
ioa = bbox_ioa(instance.bboxes, instances.bboxes) # intersection over area, (N, M)
|
||||
i = (ioa < 0.30).all(1) # (N, )
|
||||
if i.sum():
|
||||
cls = np.concatenate((cls, c[i]), axis=0)
|
||||
instances = Instances.concatenate((instances, instance[i]), axis=0)
|
||||
cv2.drawContours(im_new, instances.segments[j][i].astype(np.int32), -1, (255, 255, 255), cv2.FILLED)
|
||||
|
||||
result = cv2.bitwise_and(src1=im, src2=im_new)
|
||||
result = cv2.flip(result, 1) # augment segments (flip left-right)
|
||||
@ -543,7 +538,7 @@ class CopyPaste:
|
||||
im[i] = result[i] # cv2.imwrite('debug.jpg', im) # debug
|
||||
labels["img"] = im
|
||||
labels["cls"] = cls
|
||||
labels["instances"].update(bboxes, segments, keypoints)
|
||||
labels["instances"] = instances
|
||||
return labels
|
||||
|
||||
|
||||
|
@ -252,23 +252,36 @@ class Instances:
|
||||
)
|
||||
|
||||
def flipud(self, h):
|
||||
# this function may not be very logical, just for clean code when using augment flipud
|
||||
if self._bboxes.format == "xyxy":
|
||||
y1 = self.bboxes[:, 1].copy()
|
||||
y2 = self.bboxes[:, 3].copy()
|
||||
self.bboxes[:, 1] = h - y2
|
||||
self.bboxes[:, 3] = h - y1
|
||||
else:
|
||||
self.bboxes[:, 1] = h - self.bboxes[:, 1]
|
||||
self.segments[..., 1] = h - self.segments[..., 1]
|
||||
if self.keypoints is not None:
|
||||
self.keypoints[..., 1] = h - self.keypoints[..., 1]
|
||||
|
||||
def fliplr(self, w):
|
||||
# this function may not be very logical, just for clean code when using augment fliplr
|
||||
if self._bboxes.format == "xyxy":
|
||||
x1 = self.bboxes[:, 0].copy()
|
||||
x2 = self.bboxes[:, 2].copy()
|
||||
self.bboxes[:, 0] = w - x2
|
||||
self.bboxes[:, 2] = w - x1
|
||||
else:
|
||||
self.bboxes[:, 0] = w - self.bboxes[:, 0]
|
||||
self.segments[..., 0] = w - self.segments[..., 0]
|
||||
if self.keypoints is not None:
|
||||
self.keypoints[..., 0] = w - self.keypoints[..., 0]
|
||||
|
||||
def clip(self, w, h):
|
||||
ori_format = self._bboxes.format
|
||||
self.convert_bbox(format="xyxy")
|
||||
self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
|
||||
self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
|
||||
if ori_format != "xyxy":
|
||||
self.convert_bbox(format=ori_format)
|
||||
self.segments[..., 0] = self.segments[..., 0].clip(0, w)
|
||||
self.segments[..., 1] = self.segments[..., 1].clip(0, h)
|
||||
if self.keypoints is not None:
|
||||
|
@ -14,18 +14,18 @@ def box_area(box):
|
||||
|
||||
def bbox_ioa(box1, box2, eps=1e-7):
|
||||
"""Returns the intersection over box2 area given box1, box2. Boxes are x1y1x2y2
|
||||
box1: np.array of shape(4)
|
||||
box2: np.array of shape(nx4)
|
||||
returns: np.array of shape(n)
|
||||
box1: np.array of shape(nx4)
|
||||
box2: np.array of shape(mx4)
|
||||
returns: np.array of shape(nxm)
|
||||
"""
|
||||
|
||||
# Get the coordinates of bounding boxes
|
||||
b1_x1, b1_y1, b1_x2, b1_y2 = box1
|
||||
b1_x1, b1_y1, b1_x2, b1_y2 = box1.T
|
||||
b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
|
||||
|
||||
# Intersection area
|
||||
inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
|
||||
(np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)
|
||||
inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * \
|
||||
(np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)).clip(0)
|
||||
|
||||
# box2 area
|
||||
box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + eps
|
||||
|
Loading…
x
Reference in New Issue
Block a user