diff --git a/ultralytics/tests/data/dataloader/hyp_test.yaml b/ultralytics/tests/data/dataloader/hyp_test.yaml
new file mode 100644
index 00000000..a31eef72
--- /dev/null
+++ b/ultralytics/tests/data/dataloader/hyp_test.yaml
@@ -0,0 +1,29 @@
+lr0: 0.001  # initial learning rate (SGD=1E-2, Adam=1E-3)
+lrf: 0.01  # final OneCycleLR learning rate (lr0 * lrf)
+momentum: 0.937  # SGD momentum/Adam beta1
+weight_decay: 0.0005  # optimizer weight decay 5e-4
+warmup_epochs: 3.0  # warmup epochs (fractions ok)
+warmup_momentum: 0.8  # warmup initial momentum
+warmup_bias_lr: 0.1  # warmup initial bias lr
+box: 0.05  # box loss gain
+cls: 0.5  # cls loss gain
+cls_pw: 1.0  # cls BCELoss positive_weight
+obj: 1.0  # obj loss gain (scale with pixels)
+obj_pw: 1.0  # obj BCELoss positive_weight
+iou_t: 0.20  # IoU training threshold
+anchor_t: 4.0  # anchor-multiple threshold
+# anchors: 3  # anchors per output layer (0 to ignore)
+fl_gamma: 0.0  # focal loss gamma (EfficientDet default gamma=1.5)
+hsv_h: 0.015  # image HSV-Hue augmentation (fraction)
+hsv_s: 0.7  # image HSV-Saturation augmentation (fraction)
+hsv_v: 0.4  # image HSV-Value augmentation (fraction)
+degrees: 0.0  # image rotation (+/- deg)
+translate: 0.1  # image translation (+/- fraction)
+scale: 0.5  # image scale (+/- gain)
+shear: 0.0  # image shear (+/- deg)
+perspective: 0.0  # image perspective (+/- fraction), range 0-0.001
+flipud: 0.0  # image flip up-down (probability)
+fliplr: 0.5  # image flip left-right (probability)
+mosaic: 1.0  # image mosaic (probability)
+mixup: 0.0  # image mixup (probability)
+copy_paste: 0.0  # segment copy-paste (probability)
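
Note: the test scripts added below load this file through OmegaConf, so every key is available via attribute access. A minimal sketch of that pattern (the `mosaic` override is purely illustrative, not part of this PR):

    from omegaconf import OmegaConf

    # Load the test hyperparameters; OmegaConf.load returns a DictConfig.
    hyp = OmegaConf.load("ultralytics/tests/data/dataloader/hyp_test.yaml")
    print(hyp.lr0, hyp.mosaic)  # 0.001 1.0

    # Individual values can be overridden before passing hyp to build_dataloader.
    hyp.mosaic = 0.0
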
diff --git a/ultralytics/tests/data/dataloader/yolodetection.py b/ultralytics/tests/data/dataloader/yolodetection.py
new file mode 100644
index 00000000..7d37a84f
--- /dev/null
+++ b/ultralytics/tests/data/dataloader/yolodetection.py
@@ -0,0 +1,97 @@
+import cv2
+import numpy as np
+from omegaconf import OmegaConf
+
+from ultralytics.yolo.data import build_dataloader
+
+
+class Colors:
+    # Ultralytics color palette https://ultralytics.com/
+    def __init__(self):
+        # hex = matplotlib.colors.TABLEAU_COLORS.values()
+        hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
+                '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
+        self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
+        self.n = len(self.palette)
+
+    def __call__(self, i, bgr=False):
+        c = self.palette[int(i) % self.n]
+        return (c[2], c[1], c[0]) if bgr else c
+
+    @staticmethod
+    def hex2rgb(h):  # rgb order (PIL)
+        return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
+
+
+colors = Colors()  # single reusable palette instance used by the plotting helpers below
+
+
+def plot_one_box(x, img, color=None, label=None, line_thickness=None):
+    import random
+
+    # Plots one bounding box on image img
+    tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
+    color = color or [random.randint(0, 255) for _ in range(3)]
+    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
+    cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
+    if label:
+        tf = max(tl - 1, 1)  # font thickness
+        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
+        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
+        cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # filled
+        cv2.putText(
+            img,
+            label,
+            (c1[0], c1[1] - 2),
+            0,
+            tl / 3,
+            [225, 255, 255],
+            thickness=tf,
+            lineType=cv2.LINE_AA,
+        )
+
+
+with open("ultralytics/tests/data/dataloader/hyp_test.yaml") as f:
+    hyp = OmegaConf.load(f)
+
+dataloader, dataset = build_dataloader(
+    img_path="/d/dataset/COCO/coco128-seg/images",
+    img_size=640,
+    label_path=None,
+    cache=False,
+    hyp=hyp,
+    augment=False,
+    prefix="",
+    rect=False,
+    batch_size=4,
+    stride=32,
+    pad=0.5,
+    use_segments=True,
+    use_keypoints=False,
+)
+
+for d in dataloader:
+    idx = 1  # index of the image within the batch to visualize
+    img = d["img"][idx].numpy()
+    img = np.ascontiguousarray(img.transpose(1, 2, 0))
+    ih, iw = img.shape[:2]
+    # print(img.shape)
+    bidx = d["batch_idx"]
+    cls = d["cls"][bidx == idx].numpy()
+    bboxes = d["bboxes"][bidx == idx].numpy()
+    print(bboxes.shape)
+    bboxes[:, [0, 2]] *= iw
+    bboxes[:, [1, 3]] *= ih
+    nl = len(cls)
+
+    for i, b in enumerate(bboxes):
+        x, y, w, h = b
+        x1 = x - w / 2
+        x2 = x + w / 2
+        y1 = y - h / 2
+        y2 = y + h / 2
+        c = int(cls[i][0])
+        plot_one_box([int(x1), int(y1), int(x2), int(y2)], img, label=f"{c}", color=colors(c))
+    cv2.imshow("p", img)
+    if cv2.waitKey(0) == ord("q"):
+        break
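
The per-box loop above converts xywh centers (already scaled to pixels) to xyxy corners one box at a time. A vectorized NumPy equivalent is sketched below for reference; `xywh2xyxy` is a hypothetical helper, not something added by this PR:

    import numpy as np

    def xywh2xyxy(boxes: np.ndarray) -> np.ndarray:
        # boxes: (N, 4) as (cx, cy, w, h) in pixels -> (N, 4) as (x1, y1, x2, y2)
        out = np.empty_like(boxes)
        out[:, 0] = boxes[:, 0] - boxes[:, 2] / 2  # x1
        out[:, 1] = boxes[:, 1] - boxes[:, 3] / 2  # y1
        out[:, 2] = boxes[:, 0] + boxes[:, 2] / 2  # x2
        out[:, 3] = boxes[:, 1] + boxes[:, 3] / 2  # y2
        return out
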
diff --git a/ultralytics/tests/data/dataloader/yolopose.py b/ultralytics/tests/data/dataloader/yolopose.py
new file mode 100644
index 00000000..e36ed1d1
--- /dev/null
+++ b/ultralytics/tests/data/dataloader/yolopose.py
@@ -0,0 +1,114 @@
+import cv2
+import numpy as np
+import torch
+from omegaconf import OmegaConf
+
+from ultralytics.yolo.data import build_dataloader
+
+
+class Colors:
+    # Ultralytics color palette https://ultralytics.com/
+    def __init__(self):
+        # hex = matplotlib.colors.TABLEAU_COLORS.values()
+        hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
+                '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
+        self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
+        self.n = len(self.palette)
+
+    def __call__(self, i, bgr=False):
+        c = self.palette[int(i) % self.n]
+        return (c[2], c[1], c[0]) if bgr else c
+
+    @staticmethod
+    def hex2rgb(h):  # rgb order (PIL)
+        return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
+
+
+colors = Colors()  # single reusable palette instance used by the plotting helpers below
+
+
+def plot_one_box(x, img, keypoints=None, color=None, label=None, line_thickness=None):
+    import random
+
+    # Plots one bounding box on image img
+    tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
+    color = color or [random.randint(0, 255) for _ in range(3)]
+    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
+    cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
+    if label:
+        tf = max(tl - 1, 1)  # font thickness
+        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
+        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
+        cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # filled
+        cv2.putText(
+            img,
+            label,
+            (c1[0], c1[1] - 2),
+            0,
+            tl / 3,
+            [225, 255, 255],
+            thickness=tf,
+            lineType=cv2.LINE_AA,
+        )
+    if keypoints is not None:
+        plot_keypoint(img, keypoints, color, tl)
+
+
+def plot_keypoint(img, keypoints, color, tl):
+    num_l = len(keypoints)
+    # clors = [(255, 0, 0),(0, 255, 0),(0, 0, 255),(255, 255, 0),(0, 255, 255)]
+    # clors = [[random.randint(0, 255) for _ in range(3)] for _ in range(num_l)]
+    for i in range(num_l):
+        point_x = int(keypoints[i][0])
+        point_y = int(keypoints[i][1])
+        cv2.circle(img, (point_x, point_y), tl + 3, color, -1)
+
+
+with open("ultralytics/tests/data/dataloader/hyp_test.yaml") as f:
+    hyp = OmegaConf.load(f)
+
+dataloader, dataset = build_dataloader(
+    img_path="/d/dataset/COCO/images/val2017",
+    img_size=640,
+    label_path=None,
+    cache=False,
+    hyp=hyp,
+    augment=False,
+    prefix="",
+    rect=False,
+    batch_size=4,
+    stride=32,
+    pad=0.5,
+    use_segments=False,
+    use_keypoints=True,
+)
+
+for d in dataloader:
+    idx = 1  # index of the image within the batch to visualize
+    img = d["img"][idx].numpy()
+    img = np.ascontiguousarray(img.transpose(1, 2, 0))
+    ih, iw = img.shape[:2]
+    # print(img.shape)
+    bidx = d["batch_idx"]
+    cls = d["cls"][bidx == idx].numpy()
+    bboxes = d["bboxes"][bidx == idx].numpy()
+    bboxes[:, [0, 2]] *= iw
+    bboxes[:, [1, 3]] *= ih
+    keypoints = d["keypoints"][bidx == idx]
+    keypoints[..., 0] *= iw
+    keypoints[..., 1] *= ih
+    # print(keypoints, keypoints.shape)
+    # print(d["im_file"])
+
+    for i, b in enumerate(bboxes):
+        x, y, w, h = b
+        x1 = x - w / 2
+        x2 = x + w / 2
+        y1 = y - h / 2
+        y2 = y + h / 2
+        c = int(cls[i][0])
+        # print(x1, y1, x2, y2)
+        plot_one_box([int(x1), int(y1), int(x2), int(y2)], img, keypoints=keypoints[i], label=f"{c}", color=colors(c))
+    cv2.imshow("p", img)
+    if cv2.waitKey(0) == ord("q"):
+        break
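
Keypoints come out of the Format transform normalized with shape (nl, 17, 2), which is why the script multiplies by the image width/height before drawing. The sketch below wraps that scaling in a helper; the (0, 0) filter assumes unlabeled keypoints are zero-padded, an assumption about the dataset rather than something this PR guarantees:

    import numpy as np

    def denormalize_keypoints(kpts: np.ndarray, w: int, h: int) -> np.ndarray:
        # kpts: (N, 17, 2) normalized xy -> pixel coordinates, same shape
        kpts = kpts.copy()
        kpts[..., 0] *= w
        kpts[..., 1] *= h
        return kpts

    def drawable_points(kpts_one: np.ndarray) -> np.ndarray:
        # Drop points that are exactly (0, 0), assumed to mark unlabeled keypoints.
        keep = (kpts_one[:, 0] > 0) | (kpts_one[:, 1] > 0)
        return kpts_one[keep]
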
diff --git a/ultralytics/tests/data/dataloader/yolosegment.py b/ultralytics/tests/data/dataloader/yolosegment.py
new file mode 100644
index 00000000..ae99aa5e
--- /dev/null
+++ b/ultralytics/tests/data/dataloader/yolosegment.py
@@ -0,0 +1,112 @@
+import cv2
+import numpy as np
+import torch
+from omegaconf import OmegaConf
+
+from ultralytics.yolo.data import build_dataloader
+
+
+class Colors:
+    # Ultralytics color palette https://ultralytics.com/
+    def __init__(self):
+        # hex = matplotlib.colors.TABLEAU_COLORS.values()
+        hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
+                '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
+        self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
+        self.n = len(self.palette)
+
+    def __call__(self, i, bgr=False):
+        c = self.palette[int(i) % self.n]
+        return (c[2], c[1], c[0]) if bgr else c
+
+    @staticmethod
+    def hex2rgb(h):  # rgb order (PIL)
+        return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
+
+
+colors = Colors()  # single reusable palette instance used by the plotting helpers below
+
+
+def plot_one_box(x, img, color=None, label=None, line_thickness=None):
+    import random
+
+    # Plots one bounding box on image img
+    tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
+    color = color or [random.randint(0, 255) for _ in range(3)]
+    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
+    cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
+    if label:
+        tf = max(tl - 1, 1)  # font thickness
+        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
+        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
+        cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # filled
+        cv2.putText(
+            img,
+            label,
+            (c1[0], c1[1] - 2),
+            0,
+            tl / 3,
+            [225, 255, 255],
+            thickness=tf,
+            lineType=cv2.LINE_AA,
+        )
+
+
+with open("ultralytics/tests/data/dataloader/hyp_test.yaml") as f:
+    hyp = OmegaConf.load(f)
+
+dataloader, dataset = build_dataloader(
+    img_path="/d/dataset/COCO/coco128-seg/images",
+    img_size=640,
+    label_path=None,
+    cache=False,
+    hyp=hyp,
+    augment=False,
+    prefix="",
+    rect=False,
+    batch_size=4,
+    stride=32,
+    pad=0.5,
+    use_segments=True,
+    use_keypoints=False,
+)
+
+for d in dataloader:
+    idx = 1  # index of the image within the batch to visualize
+    img = d["img"][idx].numpy()
+    img = np.ascontiguousarray(img.transpose(1, 2, 0))
+    ih, iw = img.shape[:2]
+    # print(img.shape)
+    bidx = d["batch_idx"]
+    cls = d["cls"][bidx == idx].numpy()
+    bboxes = d["bboxes"][bidx == idx].numpy()
+    masks = d["masks"][idx]
+    print(bboxes.shape)
+    bboxes[:, [0, 2]] *= iw
+    bboxes[:, [1, 3]] *= ih
+    nl = len(cls)
+
+    index = torch.arange(nl).view(nl, 1, 1) + 1
+    masks = masks.repeat(nl, 1, 1)
+    # print(masks.shape, index.shape)
+    masks = torch.where(masks == index, 1, 0)
+    masks = masks.numpy().astype(np.uint8)
+    print(masks.shape)
+    # keypoints = d["keypoints"]
+
+    for i, b in enumerate(bboxes):
+        x, y, w, h = b
+        x1 = x - w / 2
+        x2 = x + w / 2
+        y1 = y - h / 2
+        y2 = y + h / 2
+        c = int(cls[i][0])
+        # print(x1, y1, x2, y2)
+        plot_one_box([int(x1), int(y1), int(x2), int(y2)], img, label=f"{c}", color=colors(c))
+        mask = masks[i]
+        mask = cv2.resize(mask, (iw, ih))
+        mask = mask.astype(bool)
+        img[mask] = img[mask] * 0.5 + np.array(colors(c)) * 0.5
+    cv2.imshow("p", img)
+    if cv2.waitKey(0) == ord("q"):
+        break
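
With `mask_overlap=True`, the Format transform packs every instance of an image into one index mask where pixel value i+1 belongs to instance i and 0 is background; the `torch.arange` broadcast above decodes it back into per-instance binary masks. The same step as a standalone sketch (a reference helper, not code added by this PR):

    import torch

    def decode_overlap_mask(index_mask: torch.Tensor, n_instances: int) -> torch.Tensor:
        # index_mask: (H, W) with values in {0, 1, ..., n_instances}
        # returns: (n_instances, H, W) binary masks as uint8
        idx = torch.arange(n_instances, device=index_mask.device).view(-1, 1, 1) + 1
        return (index_mask.unsqueeze(0) == idx).to(torch.uint8)
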
diff --git a/ultralytics/yolo/data/augment.py b/ultralytics/yolo/data/augment.py
index 6c936ad8..af752408 100644
--- a/ultralytics/yolo/data/augment.py
+++ b/ultralytics/yolo/data/augment.py
@@ -127,7 +127,7 @@ class Mosaic(BaseMixTransform):
         self.border = border
 
     def get_indexes(self, dataset):
-        return [random.randint(0, len(dataset)) for _ in range(3)]
+        return [random.randint(0, len(dataset) - 1) for _ in range(3)]
 
     def _mix_transform(self, labels):
         mosaic_labels = []
@@ -200,7 +200,7 @@ class MixUp(BaseMixTransform):
         super().__init__(pre_transform=pre_transform, p=p)
 
     def get_indexes(self, dataset):
-        return random.randint(0, len(dataset))
+        return random.randint(0, len(dataset) - 1)
 
     def _mix_transform(self, labels):
         im = labels["img"]
@@ -366,7 +366,7 @@ class RandomPerspective:
         segments = instances.segments
         keypoints = instances.keypoints
         # update bboxes if there are segments.
-        if segments is not None:
+        if len(segments):
             bboxes, segments = self.apply_segments(segments, M)
 
         if keypoints is not None:
@@ -379,7 +379,7 @@ class RandomPerspective:
         # make the bboxes have the same scale with new_bboxes
         i = self.box_candidates(box1=instances.bboxes.T,
                                 box2=new_instances.bboxes.T,
-                                area_thr=0.01 if segments is not None else 0.10)
+                                area_thr=0.01 if len(segments) else 0.10)
         labels["instances"] = new_instances[i]
         # clip
         labels["cls"] = cls[i]
@@ -518,7 +518,7 @@ class CopyPaste:
         bboxes = labels["instances"].bboxes
         segments = labels["instances"].segments  # n, 1000, 2
         keypoints = labels["instances"].keypoints
-        if self.p and segments is not None:
+        if self.p and len(segments):
             n = len(segments)
             h, w, _ = im.shape  # height, width, channels
             im_new = np.zeros(im.shape, np.uint8)
@@ -593,10 +593,18 @@ class Albumentations:
 # TODO: technically this is not an augmentation, maybe we should put this to another files
 class Format:
 
-    def __init__(self, bbox_format="xywh", normalize=True, mask=False, mask_ratio=4, mask_overlap=True, batch_idx=True):
+    def __init__(self,
+                 bbox_format="xywh",
+                 normalize=True,
+                 return_mask=False,
+                 return_keypoint=False,
+                 mask_ratio=4,
+                 mask_overlap=True,
+                 batch_idx=True):
         self.bbox_format = bbox_format
         self.normalize = normalize
-        self.mask = mask  # set False when training detection only
+        self.return_mask = return_mask  # set False when training detection only
+        self.return_keypoint = return_keypoint
         self.mask_ratio = mask_ratio
         self.mask_overlap = mask_overlap
         self.batch_idx = batch_idx  # keep the batch indexes
@@ -610,16 +618,20 @@ class Format:
         instances.denormalize(w, h)
         nl = len(instances)
 
-        if instances.segments is not None and self.mask:
-            masks, instances, cls = self._format_segments(instances, cls, w, h)
-            labels["masks"] = (torch.from_numpy(masks) if nl else torch.zeros(
-                1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, img.shape[1] // self.mask_ratio))
+        if self.return_mask:
+            if nl:
+                masks, instances, cls = self._format_segments(instances, cls, w, h)
+                masks = torch.from_numpy(masks)
+            else:
+                masks = torch.zeros(1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio,
+                                    img.shape[1] // self.mask_ratio)
+            labels["masks"] = masks
         if self.normalize:
             instances.normalize(w, h)
         labels["img"] = self._format_img(img)
         labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl)
         labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
-        if instances.keypoints is not None:
+        if self.return_keypoint:
             labels["keypoints"] = torch.from_numpy(instances.keypoints) if nl else torch.zeros((nl, 17, 2))
         # then we can use collate_fn
         if self.batch_idx:
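
The old boolean `mask` flag is split into explicit `return_mask` and `return_keypoint` switches, so each task opts into its extra targets instead of relying on whether segments or keypoints happen to be attached to the labels. A usage sketch, mirroring what YOLODataset.build_transforms does in the next hunk:

    from ultralytics.yolo.data.augment import Format

    # Detection: normalized xywh boxes, "cls" and "batch_idx" only.
    fmt_det = Format(bbox_format="xywh", normalize=True, batch_idx=True)

    # Segmentation: additionally emits labels["masks"], downsampled by mask_ratio
    # and overlap-encoded when mask_overlap=True.
    fmt_seg = Format(bbox_format="xywh", normalize=True, return_mask=True, batch_idx=True)

    # Pose: additionally emits labels["keypoints"] with shape (nl, 17, 2).
    fmt_pose = Format(bbox_format="xywh", normalize=True, return_keypoint=True, batch_idx=True)
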
diff --git a/ultralytics/yolo/data/dataset.py b/ultralytics/yolo/data/dataset.py
index b17fc46f..e332aed1 100644
--- a/ultralytics/yolo/data/dataset.py
+++ b/ultralytics/yolo/data/dataset.py
@@ -132,7 +132,12 @@ class YOLODataset(BaseDataset):
                 transforms = affine_transforms(self.img_size, hyp)
         else:
             transforms = Compose([LetterBox(new_shape=(self.img_size, self.img_size))])
-        transforms.append(Format(bbox_format="xywh", normalize=True, mask=self.use_segments, batch_idx=True))
+        transforms.append(
+            Format(bbox_format="xywh",
+                   normalize=True,
+                   return_mask=self.use_segments,
+                   return_keypoint=self.use_keypoints,
+                   batch_idx=True))
         return transforms
 
     def update_labels_info(self, label):
@@ -140,7 +145,7 @@ class YOLODataset(BaseDataset):
         # NOTE: cls is not with bboxes now, since other tasks like classification and semantic segmentation need a independent cls label
         # we can make it also support classification and semantic segmentation by add or remove some dict keys there.
         bboxes = label.pop("bboxes")
-        segments = label.pop("segments", None)
+        segments = label.pop("segments")
         keypoints = label.pop("keypoints", None)
         bbox_format = label.pop("bbox_format")
         normalized = label.pop("normalized")
@@ -158,9 +163,9 @@ class YOLODataset(BaseDataset):
             value = values[i]
             if k == "img":
                 value = torch.stack(value, 0)
-            if k in ["mask", "keypoint", "bboxes", "cls"]:
+            if k in ["masks", "keypoints", "bboxes", "cls"]:
                 value = torch.cat(value, 0)
-            new_batch[k] = values[i]
+            new_batch[k] = value
         new_batch["batch_idx"] = list(new_batch["batch_idx"])
         for i in range(len(new_batch["batch_idx"])):
             new_batch["batch_idx"][i] += i  # add target image index for build_targets()
diff --git a/ultralytics/yolo/data/utils.py b/ultralytics/yolo/data/utils.py
index 00277424..63d59622 100644
--- a/ultralytics/yolo/data/utils.py
+++ b/ultralytics/yolo/data/utils.py
@@ -52,7 +52,7 @@ def verify_image_label(args):
     # Verify one image-label pair
     im_file, lb_file, prefix, keypoint = args
     # number (missing, found, empty, corrupt), message, segments, keypoints
-    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", None, None
+    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
     try:
         # verify images
         im = Image.open(im_file)
diff --git a/ultralytics/yolo/utils/instance.py b/ultralytics/yolo/utils/instance.py
index 1481ce34..0c1b7ff5 100644
--- a/ultralytics/yolo/utils/instance.py
+++ b/ultralytics/yolo/utils/instance.py
@@ -162,7 +162,7 @@ class Bboxes:
 
 class Instances:
 
-    def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
+    def __init__(self, bboxes, segments=[], keypoints=None, bbox_format="xywh", normalized=True) -> None:
         """
         Args:
             bboxes (ndarray): bboxes with shape [N, 4].
@@ -173,11 +173,13 @@ class Instances:
         self.keypoints = keypoints
         self.normalized = normalized
 
-        if isinstance(segments, list) and len(segments) > 0:
+        if len(segments) > 0:
             # list[np.array(1000, 2)] * num_samples
             segments = resample_segments(segments)
             # (N, 1000, 2)
             segments = np.stack(segments, axis=0)
+        else:
+            segments = np.zeros((0, 1000, 2), dtype=np.float32)
         self.segments = segments
 
     def convert_bbox(self, format):
@@ -191,9 +193,8 @@ class Instances:
         self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
         if bbox_only:
             return
-        if self.segments is not None:
-            self.segments[..., 0] *= scale_w
-            self.segments[..., 1] *= scale_h
+        self.segments[..., 0] *= scale_w
+        self.segments[..., 1] *= scale_h
         if self.keypoints is not None:
             self.keypoints[..., 0] *= scale_w
             self.keypoints[..., 1] *= scale_h
@@ -202,9 +203,8 @@ class Instances:
         if not self.normalized:
             return
         self._bboxes.mul(scale=(w, h, w, h))
-        if self.segments is not None:
-            self.segments[..., 0] *= w
-            self.segments[..., 1] *= h
+        self.segments[..., 0] *= w
+        self.segments[..., 1] *= h
         if self.keypoints is not None:
             self.keypoints[..., 0] *= w
             self.keypoints[..., 1] *= h
@@ -214,9 +214,8 @@ class Instances:
         if self.normalized:
             return
         self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h))
-        if self.segments is not None:
-            self.segments[..., 0] /= w
-            self.segments[..., 1] /= h
+        self.segments[..., 0] /= w
+        self.segments[..., 1] /= h
         if self.keypoints is not None:
             self.keypoints[..., 0] /= w
             self.keypoints[..., 1] /= h
@@ -226,9 +225,8 @@ class Instances:
         # handle rect and mosaic situation
         assert not self.normalized, "you should add padding with absolute coordinates."
         self._bboxes.add(offset=(padw, padh, padw, padh))
-        if self.segments is not None:
-            self.segments[..., 0] += padw
-            self.segments[..., 1] += padh
+        self.segments[..., 0] += padw
+        self.segments[..., 1] += padh
         if self.keypoints is not None:
             self.keypoints[..., 0] += padw
             self.keypoints[..., 1] += padh
@@ -241,7 +239,7 @@ class Instances:
         Returns:
             Instances: Create a new :class:`Instances` by indexing.
         """
-        segments = self.segments[index] if self.segments is not None else None
+        segments = self.segments[index] if len(self.segments) else self.segments
         keypoints = self.keypoints[index] if self.keypoints is not None else None
         bboxes = self.bboxes[index]
         bbox_format = self._bboxes.format
@@ -256,16 +254,14 @@ class Instances:
     def flipud(self, h):
         # this function may not be very logical, just for clean code when using augment flipud
         self.bboxes[:, 1] = h - self.bboxes[:, 1]
-        if self.segments is not None:
-            self.segments[..., 1] = h - self.segments[..., 1]
+        self.segments[..., 1] = h - self.segments[..., 1]
         if self.keypoints is not None:
             self.keypoints[..., 1] = h - self.keypoints[..., 1]
 
     def fliplr(self, w):
         # this function may not be very logical, just for clean code when using augment fliplr
         self.bboxes[:, 0] = w - self.bboxes[:, 0]
-        if self.segments is not None:
-            self.segments[..., 0] = w - self.segments[..., 0]
+        self.segments[..., 0] = w - self.segments[..., 0]
         if self.keypoints is not None:
             self.keypoints[..., 0] = w - self.keypoints[..., 0]
 
@@ -273,9 +269,8 @@ class Instances:
         self.convert_bbox(format="xyxy")
         self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
         self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
-        if self.segments is not None:
-            self.segments[..., 0] = self.segments[..., 0].clip(0, w)
-            self.segments[..., 1] = self.segments[..., 1].clip(0, h)
+        self.segments[..., 0] = self.segments[..., 0].clip(0, w)
+        self.segments[..., 1] = self.segments[..., 1].clip(0, h)
         if self.keypoints is not None:
             self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w)
             self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)
@@ -311,13 +306,12 @@ class Instances:
         if len(instances_list) == 1:
             return instances_list[0]
 
-        use_segment = instances_list[0].segments is not None
         use_keypoint = instances_list[0].keypoints is not None
         bbox_format = instances_list[0]._bboxes.format
         normalized = instances_list[0].normalized
 
         cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)
-        cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis) if use_segment else None
+        cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis)
         cat_keypoints = np.concatenate([b.keypoints for b in instances_list], axis=axis) if use_keypoint else None
         return cls(cat_boxes, cat_segments, cat_keypoints, bbox_format, normalized)
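
The net effect of the instance.py changes is that `segments` is never None: an Instances built without segments now carries an empty (0, 1000, 2) float32 array, so call sites can test `len(segments)` uniformly and the in-place segment ops above simply broadcast over zero rows. A small sketch of the invariant (the random boxes are placeholder data):

    import numpy as np
    from ultralytics.yolo.utils.instance import Instances

    boxes = np.random.rand(3, 4).astype(np.float32)  # placeholder normalized xywh boxes

    inst = Instances(boxes)          # no segments supplied
    print(inst.segments.shape)       # (0, 1000, 2): empty array instead of None
    print(len(inst.segments))        # 0, so `if len(segments):` branches are skipped

    inst.denormalize(640, 640)       # segment updates are no-ops on the empty array
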