diff --git a/.gitignore b/.gitignore
index b6e47617..75b43690 100644
--- a/.gitignore
+++ b/.gitignore
@@ -127,3 +127,7 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+# datasets and projects
+datasets/
+ultralytics-yolo/
\ No newline at end of file
diff --git a/MANIFEST.in b/MANIFEST.in
index 996eaa28..1635ec15 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,12 +1,5 @@
-# Include the README
 include *.md
 include requirements.txt
-
-# Include the license file
 include LICENSE
-
-# Include setup.py
 include setup.py
-
-# Include the data files
-recursive-include data *
+recursive-include ultralytics *.yaml
diff --git a/requirements.txt b/requirements.txt
index 4077e293..eb7cb418 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,7 @@
 
 # Base ----------------------------------------
 fire>=0.4.0
+hydra-core>=1.2.0
 matplotlib>=3.2.2
 numpy>=1.18.5
 opencv-python>=4.1.1
@@ -44,4 +45,3 @@ thop>=0.1.1  # FLOPs computation
 
 # HUB -----------------------------------------
 GitPython>=3.1.24
-requests
diff --git a/ultralytics/yolo/__init__.py b/ultralytics/yolo/__init__.py
index e69de29b..b307b3b4 100644
--- a/ultralytics/yolo/__init__.py
+++ b/ultralytics/yolo/__init__.py
@@ -0,0 +1,3 @@
+from .engine.trainer import BaseTrainer
+
+__all__ = ["BaseTrainer"]  # allow simpler import
diff --git a/ultralytics/yolo/data/__init__.py b/ultralytics/yolo/data/__init__.py
new file mode 100644
index 00000000..00e903aa
--- /dev/null
+++ b/ultralytics/yolo/data/__init__.py
@@ -0,0 +1,3 @@
+from .build import build_classification_dataloader, build_dataloader
+from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
+from .dataset_wrappers import MixAndRectDataset
diff --git a/ultralytics/yolo/data/augment.py b/ultralytics/yolo/data/augment.py
new file mode 100644
index 00000000..684ab935
--- /dev/null
+++ b/ultralytics/yolo/data/augment.py
@@ -0,0 +1,785 @@
+import collections
+import math
+import random
+from copy import deepcopy
+
+import cv2
+import numpy as np
+import torch
+import torchvision.transforms as T
+
+from ..utils.general import LOGGER, check_version, colorstr, segment2box
+from ..utils.instance import Instances
+from ..utils.metrics import bbox_ioa
+from .utils import IMAGENET_MEAN, IMAGENET_STD, polygons2masks, polygons2masks_overlap
+
+
+# TODO: we might need a BaseTransform to make all these augmentations compatible with both classification and semantic segmentation
+class BaseTransform:
+
+    def __init__(self) -> None:
+        pass
+
+    def apply_image(self, labels):
+        pass
+
+    def apply_instances(self, labels):
+        pass
+
+    def apply_semantic(self, labels):
+        pass
+
+    def __call__(self, labels):
+        self.apply_image(labels)
+        self.apply_instances(labels)
+        self.apply_semantic(labels)
+
+
+class Compose:
+
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, data):
+        for t in self.transforms:
+            data = t(data)
+        return data
+
+    def append(self, transform):
+        self.transforms.append(transform)
+
+    def tolist(self):
+        return self.transforms
+
+    def __repr__(self):
+        format_string = f"{self.__class__.__name__}("
+        for t in self.transforms:
+            format_string += "\n"
+            format_string += f"    {t}"
+        format_string += "\n)"
+        return format_string
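+
+# Compose usage (a minimal sketch; RandomHSV and RandomFlip are defined below):
+#   transforms = Compose([RandomHSV(0.5, 0.5, 0.5), RandomFlip(p=0.5)])
+#   labels = transforms(labels)  # each transform takes and returns the labels dict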
+
+
+class BaseMixTransform:
+    """This implementation is from mmyolo"""
+
+    def __init__(self, pre_transform=None, p=0.0) -> None:
+        self.pre_transform = pre_transform
+        self.p = p
+
+    def __call__(self, labels):
+        if random.uniform(0, 1) > self.p:
+            return labels
+
+        assert "dataset" in labels
+        dataset = labels.pop("dataset")
+
+        # get index of one or three other images
+        indexes = self.get_indexes(dataset)
+        if not isinstance(indexes, collections.abc.Sequence):
+            indexes = [indexes]
+
+        # get the image information that will be used for Mosaic or MixUp
+        mix_labels = [deepcopy(dataset.get_label_info(index)) for index in indexes]
+
+        if self.pre_transform is not None:
+            for i, data in enumerate(mix_labels):
+                # pre_transform may also require dataset
+                data.update({"dataset": dataset})
+                # images need to go through the necessary pre_transform
+                # before Mosaic or MixUp
+                _labels = self.pre_transform(data)
+                _labels.pop("dataset")
+                mix_labels[i] = _labels
+        labels["mix_labels"] = mix_labels
+
+        # Mosaic or MixUp
+        labels = self._mix_transform(labels)
+
+        if "mix_labels" in labels:
+            labels.pop("mix_labels")
+        labels["dataset"] = dataset
+
+        return labels
+
+    def _mix_transform(self, labels):
+        raise NotImplementedError
+
+    def get_indexes(self, dataset):
+        raise NotImplementedError
+
+
+class Mosaic(BaseMixTransform):
+    """Mosaic augmentation.
+    Args:
+        img_size (int): Image size (height and width) after the mosaic pipeline
+            of a single image. Default to 640.
+    """
+
+    def __init__(self, img_size=640, p=1.0, border=(0, 0)):
+        assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}."
+        super().__init__(pre_transform=None, p=p)
+        self.img_size = img_size
+        self.border = border
+
+    def get_indexes(self, dataset):
+        return [random.randint(0, len(dataset) - 1) for _ in range(3)]  # randint is inclusive on both ends
+
+    def _mix_transform(self, labels):
+        mosaic_labels = []
+        assert labels.get("rect_shape", None) is None, "rect and mosaic is exclusive."
+        assert len(labels.get("mix_labels", [])) > 0, "There are no other images for mosaic augment."
+        s = self.img_size
+        yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border)  # mosaic center x, y
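+        # e.g. with img_size=640 and border=(-320, -320), yc and xc are each drawn from
+        # U(320, 960), keeping the mosaic center inside the middle of the 1280x1280 canvas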
+        mix_labels = labels["mix_labels"]
+        for i in range(4):
+            labels_patch = deepcopy(labels) if i == 0 else deepcopy(mix_labels[i - 1])
+            # Load image
+            img = labels_patch["img"]
+            h, w = labels_patch["resized_shape"]
+
+            # place img in img4
+            if i == 0:  # top left
+                img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
+                x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
+                x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
+            elif i == 1:  # top right
+                x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
+                x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
+            elif i == 2:  # bottom left
+                x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
+                x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
+            elif i == 3:  # bottom right
+                x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
+                x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
+
+            img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
+            padw = x1a - x1b
+            padh = y1a - y1b
+
+            labels_patch = self._update_labels(labels_patch, padw, padh)
+            mosaic_labels.append(labels_patch)
+        final_labels = self._cat_labels(mosaic_labels)
+        final_labels["img"] = img4
+        return final_labels
+
+    def _update_labels(self, labels, padw, padh):
+        """Update labels"""
+        nh, nw = labels["img"].shape[:2]
+        labels["instances"].convert_bbox(format="xyxy")
+        labels["instances"].denormalize(nw, nh)
+        labels["instances"].add_padding(padw, padh)
+        return labels
+
+    def _cat_labels(self, mosaic_labels):
+        if len(mosaic_labels) == 0:
+            return {}
+        cls = []
+        instances = []
+        for labels in mosaic_labels:
+            cls.append(labels["cls"])
+            instances.append(labels["instances"])
+        final_labels = {
+            "ori_shape": (self.img_size * 2, self.img_size * 2),
+            "resized_shape": (self.img_size * 2, self.img_size * 2),
+            "im_file": mosaic_labels[0]["im_file"],
+            "cls": np.concatenate(cls, 0)}
+
+        final_labels["instances"] = Instances.concatenate(instances, axis=0)
+        final_labels["instances"].clip(self.img_size * 2, self.img_size * 2)
+        return final_labels
+
+
+class MixUp(BaseMixTransform):
+
+    def __init__(self, pre_transform=None, p=0.0) -> None:
+        super().__init__(pre_transform=pre_transform, p=p)
+
+    def get_indexes(self, dataset):
+        return random.randint(0, len(dataset) - 1)  # randint is inclusive on both ends
+
+    def _mix_transform(self, labels):
+        im = labels["img"]
+        labels2 = labels["mix_labels"][0]
+        im2 = labels2["img"]
+        cls2 = labels2["cls"]
+        # Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
+        r = np.random.beta(32.0, 32.0)  # mixup ratio, alpha=beta=32.0
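+        # Beta(32, 32) concentrates r around 0.5, so the two images are blended near-equally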
+        im = (im * r + im2 * (1 - r)).astype(np.uint8)
+        cat_instances = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0)
+        cls = labels["cls"]
+        labels["img"] = im
+        labels["instances"] = cat_instances
+        labels["cls"] = np.concatenate([cls, cls2], 0)
+        return labels
+
+
+class RandomPerspective:
+
+    def __init__(self, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, border=(0, 0)):
+        self.degrees = degrees
+        self.translate = translate
+        self.scale = scale
+        self.shear = shear
+        self.perspective = perspective
+        # mosaic border
+        self.border = border
+
+    def affine_transform(self, img):
+        # Center
+        C = np.eye(3)
+
+        C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
+        C[1, 2] = -img.shape[0] / 2  # y translation (pixels)
+
+        # Perspective
+        P = np.eye(3)
+        P[2, 0] = random.uniform(-self.perspective, self.perspective)  # x perspective (about y)
+        P[2, 1] = random.uniform(-self.perspective, self.perspective)  # y perspective (about x)
+
+        # Rotation and Scale
+        R = np.eye(3)
+        a = random.uniform(-self.degrees, self.degrees)
+        # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
+        s = random.uniform(1 - self.scale, 1 + self.scale)
+        # s = 2 ** random.uniform(-scale, scale)
+        R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
+
+        # Shear
+        S = np.eye(3)
+        S[0, 1] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180)  # x shear (deg)
+        S[1, 0] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180)  # y shear (deg)
+
+        # Translation
+        T = np.eye(3)
+        T[0, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[0]  # x translation (pixels)
+        T[1, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[1]  # y translation (pixels)
+
+        # Combined rotation matrix
+        M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
+        # affine image
+        if (self.border[0] != 0) or (self.border[1] != 0) or (M != np.eye(3)).any():  # image changed
+            if self.perspective:
+                img = cv2.warpPerspective(img, M, dsize=self.size, borderValue=(114, 114, 114))
+            else:  # affine
+                img = cv2.warpAffine(img, M[:2], dsize=self.size, borderValue=(114, 114, 114))
+        return img, M, s
+
+    def apply_bboxes(self, bboxes, M):
+        """apply affine to bboxes only.
+
+        Args:
+            bboxes(ndarray): list of bboxes, xyxy format, with shape (num_bboxes, 4).
+            M(ndarray): affine matrix.
+        Returns:
+            new_bboxes(ndarray): bboxes after affine, [num_bboxes, 4].
+        """
+        n = len(bboxes)
+        if n == 0:
+            return bboxes
+
+        xy = np.ones((n * 4, 3))
+        xy[:, :2] = bboxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
+        xy = xy @ M.T  # transform
+        xy = (xy[:, :2] / xy[:, 2:3] if self.perspective else xy[:, :2]).reshape(n, 8)  # perspective rescale or affine
+
+        # create new boxes
+        x = xy[:, [0, 2, 4, 6]]
+        y = xy[:, [1, 3, 5, 7]]
+        return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
+
+    def apply_segments(self, segments, M):
+        """apply affine to segments and generate new bboxes from segments.
+
+        Args:
+            segments(ndarray): list of segments, [num_samples, 500, 2].
+            M(ndarray): affine matrix.
+        Returns:
+            new_segments(ndarray): list of segments after affine, [num_samples, 500, 2].
+            new_bboxes(ndarray): bboxes after affine, [N, 4].
+        """
+        n, num = segments.shape[:2]
+        if n == 0:
+            return [], segments
+
+        xy = np.ones((n * num, 3))
+        segments = segments.reshape(-1, 2)
+        xy[:, :2] = segments
+        xy = xy @ M.T  # transform
+        xy = xy[:, :2] / xy[:, 2:3]
+        segments = xy.reshape(n, -1, 2)
+        bboxes = np.stack([segment2box(xy, self.size[0], self.size[1]) for xy in segments], 0)
+        return bboxes, segments
+
+    def apply_keypoints(self, keypoints, M):
+        """apply affine to keypoints.
+
+        Args:
+            keypoints(ndarray): keypoints, [N, 17, 2].
+            M(ndarray): affine matrix.
+        Return:
+            new_keypoints(ndarray): keypoints after affine, [N, 17, 2].
+        """
+        n = len(keypoints)
+        if n == 0:
+            return keypoints
+        new_keypoints = np.ones((n * 17, 3))
+        new_keypoints[:, :2] = keypoints.reshape(n * 17, 2)  # num_kpt is hardcoded to 17
+        new_keypoints = new_keypoints @ M.T  # transform
+        new_keypoints = (new_keypoints[:, :2] / new_keypoints[:, 2:3]).reshape(n, 34)  # perspective rescale or affine
+        new_keypoints[keypoints.reshape(-1, 34) == 0] = 0
+        x_kpts = new_keypoints[:, list(range(0, 34, 2))]
+        y_kpts = new_keypoints[:, list(range(1, 34, 2))]
+
+        # compute the out-of-bounds mask once so both coordinate updates use the same, unmodified values
+        mask = np.logical_or.reduce((x_kpts < 0, x_kpts > self.size[0], y_kpts < 0, y_kpts > self.size[1]))
+        x_kpts[mask] = 0
+        y_kpts[mask] = 0
+        new_keypoints[:, list(range(0, 34, 2))] = x_kpts
+        new_keypoints[:, list(range(1, 34, 2))] = y_kpts
+        return new_keypoints.reshape(n, 17, 2)
+
+    def __call__(self, labels):
+        """
+        Affine images and targets.
+
+        Args:
+            img(ndarray): image.
+            labels(Dict): a dict of `bboxes`, `segments`, `keypoints`.
+        """
+        img = labels["img"]
+        cls = labels["cls"]
+        instances = labels["instances"]
+        # make sure the coord formats are right
+        instances.convert_bbox(format="xyxy")
+        instances.denormalize(*img.shape[:2][::-1])
+
+        self.size = img.shape[1] + self.border[1] * 2, img.shape[0] + self.border[0] * 2  # w, h
+        # M is affine matrix
+        # scale for func:`box_candidates`
+        img, M, scale = self.affine_transform(img)
+
+        bboxes = self.apply_bboxes(instances.bboxes, M)
+
+        segments = instances.segments
+        keypoints = instances.keypoints
+        # update bboxes if there are segments.
+        if segments is not None:
+            bboxes, segments = self.apply_segments(segments, M)
+
+        if keypoints is not None:
+            keypoints = self.apply_keypoints(keypoints, M)
+        new_instances = Instances(bboxes, segments, keypoints, bbox_format="xyxy", normalized=False)
+        new_instances.clip(*self.size)
+
+        # filter instances
+        instances.scale(scale_w=scale, scale_h=scale, bbox_only=True)
+        # make the bboxes have the same scale with new_bboxes
+        i = self.box_candidates(box1=instances.bboxes.T,
+                                box2=new_instances.bboxes.T,
+                                area_thr=0.01 if segments is not None else 0.10)
+        labels["instances"] = new_instances[i]
+        # clip
+        labels["cls"] = cls[i]
+        labels["img"] = img
+        return labels
+
+    def box_candidates(self, box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):  # box1(4,n), box2(4,n)
+        # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
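+        # e.g. a 100x100 box that shrinks to 8x8 after augmentation has an area ratio of
+        # 0.0064 < area_thr=0.1 and is filtered out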
+        w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
+        w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
+        ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps))  # aspect ratio
+        return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr)  # candidates
+
+
+class RandomHSV:
+
+    def __init__(self, hgain=0.5, sgain=0.5, vgain=0.5) -> None:
+        self.hgain = hgain
+        self.sgain = sgain
+        self.vgain = vgain
+
+    def __call__(self, labels):
+        img = labels["img"]
+        if self.hgain or self.sgain or self.vgain:
+            r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1  # random gains
+            hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
+            dtype = img.dtype  # uint8
+
+            x = np.arange(0, 256, dtype=r.dtype)
+            lut_hue = ((x * r[0]) % 180).astype(dtype)
+            lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
+            lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
+
+            im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
+            cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed
+        labels["img"] = img
+        return labels
+
+
+class RandomFlip:
+
+    def __init__(self, p=0.5, direction="horizontal") -> None:
+        assert direction in ["horizontal", "vertical"], f"Support direction `horizontal` or `vertical`, got {direction}"
+        assert 0 <= p <= 1.0
+
+        self.p = p
+        self.direction = direction
+
+    def __call__(self, labels):
+        img = labels["img"]
+        instances = labels["instances"]
+        instances.convert_bbox(format="xywh")
+        h, w = img.shape[:2]
+        h = 1 if instances.normalized else h
+        w = 1 if instances.normalized else w
+
+        # Flip up-down
+        if self.direction == "vertical" and random.random() < self.p:
+            img = np.flipud(img)
+            img = np.ascontiguousarray(img)
+            instances.flipud(h)
+        if self.direction == "horizontal" and random.random() < self.p:
+            img = np.fliplr(img)
+            img = np.ascontiguousarray(img)
+            instances.fliplr(w)
+        labels["img"] = img
+        labels["instances"] = instances
+        return labels
+
+
+class LetterBox:
+    """Resize image and padding for detection, instance segmentation, pose"""
+
+    def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32):
+        self.new_shape = new_shape
+        self.auto = auto
+        self.scaleFill = scaleFill
+        self.scaleup = scaleup
+        self.stride = stride
+
+    def __call__(self, labels):
+        img = labels["img"]
+        shape = img.shape[:2]  # current shape [height, width]
+        new_shape = labels.get("rect_shape", self.new_shape)
+        if isinstance(new_shape, int):
+            new_shape = (new_shape, new_shape)
+
+        # Scale ratio (new / old)
+        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
+        if not self.scaleup:  # only scale down, do not scale up (for better val mAP)
+            r = min(r, 1.0)
+
+        # Compute padding
+        ratio = r, r  # width, height ratios
+        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
+        dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
+        if self.auto:  # minimum rectangle
+            dw, dh = np.mod(dw, self.stride), np.mod(dh, self.stride)  # wh padding
+        elif self.scaleFill:  # stretch
+            dw, dh = 0.0, 0.0
+            new_unpad = (new_shape[1], new_shape[0])
+            ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios
+
+        dw /= 2  # divide padding into 2 sides
+        dh /= 2
+
+        if shape[::-1] != new_unpad:  # resize
+            img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
+        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+        img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT,
+                                 value=(114, 114, 114))  # add border
+
+        labels = self._update_labels(labels, ratio, dw, dh)
+        labels["img"] = img
+        return labels
+
+    def _update_labels(self, labels, ratio, padw, padh):
+        """Update labels"""
+        labels["instances"].convert_bbox(format="xyxy")
+        labels["instances"].denormalize(*labels["img"].shape[:2][::-1])
+        labels["instances"].scale(*ratio)
+        labels["instances"].add_padding(padw, padh)
+        return labels
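+
+# LetterBox worked example: a 480x640 (h, w) image letterboxed to new_shape=(640, 640)
+# keeps r=1.0, giving dw=0 and dh=160; after halving, 80 px of padding goes on the top
+# and bottom, and _update_labels() shifts all boxes by (padw=0, padh=80).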
+
+
+class CopyPaste:
+
+    def __init__(self, p=0.5) -> None:
+        self.p = p
+
+    def __call__(self, labels):
+        # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177 (labels are a dict of img, cls and instances)
+        im = labels["img"]
+        cls = labels["cls"]
+        bboxes = labels["instances"].bboxes
+        segments = labels["instances"].segments  # n, 1000, 2
+        keypoints = labels["instances"].keypoints
+        if self.p and segments is not None:
+            n = len(segments)
+            h, w, _ = im.shape  # height, width, channels
+            im_new = np.zeros(im.shape, np.uint8)
+            # TODO: this implement can be parallel since segments are ndarray, also might work with Instances inside
+            for j in random.sample(range(n), k=round(self.p * n)):
+                c, b, s = cls[j], bboxes[j], segments[j]
+                box = w - b[2], b[1], w - b[0], b[3]
+                ioa = bbox_ioa(box, bboxes)  # intersection over area
+                if (ioa < 0.30).all():  # allow 30% obscuration of existing labels
+                    bboxes = np.concatenate((bboxes, [box]), 0)
+                    cls = np.concatenate((cls, c[None]), axis=0)
+                    segments = np.concatenate((segments, np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1)[None]), 0)
+                    if keypoints is not None:
+                        keypoints = np.concatenate(
+                            (keypoints, np.concatenate((w - keypoints[j][:, 0:1], keypoints[j][:, 1:2]), 1)), 0)
+                    cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED)
+
+            result = cv2.bitwise_and(src1=im, src2=im_new)
+            result = cv2.flip(result, 1)  # augment segments (flip left-right)
+            i = result > 0  # pixels to replace
+            # i[:, :] = result.max(2).reshape(h, w, 1)  # act over ch
+            im[i] = result[i]  # cv2.imwrite('debug.jpg', im)  # debug
+        labels["img"] = im
+        labels["cls"] = cls
+        labels["instances"].update(bboxes, segments, keypoints)
+        return labels
+
+
+class Albumentations:
+    # YOLOv5 Albumentations class (optional, only used if package is installed)
+    def __init__(self, p=1.0):
+        self.p = p
+        self.transform = None
+        prefix = colorstr("albumentations: ")
+        try:
+            import albumentations as A
+
+            check_version(A.__version__, "1.0.3", hard=True)  # version requirement
+
+            T = [
+                A.Blur(p=0.01),
+                A.MedianBlur(p=0.01),
+                A.ToGray(p=0.01),
+                A.CLAHE(p=0.01),
+                A.RandomBrightnessContrast(p=0.0),
+                A.RandomGamma(p=0.0),
+                A.ImageCompression(quality_lower=75, p=0.0),]  # transforms
+            self.transform = A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]))
+
+            LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
+        except ImportError:  # package not installed, skip
+            pass
+        except Exception as e:
+            LOGGER.info(f"{prefix}{e}")
+
+    def __call__(self, labels):
+        im = labels["img"]
+        cls = labels["cls"]
+        if len(cls):
+            labels["instances"].convert_bbox("xywh")
+            labels["instances"].normalize(*im.shape[:2][::-1])
+            bboxes = labels["instances"].bboxes
+            # TODO: add support for segments and keypoints
+            if self.transform and random.random() < self.p:
+                new = self.transform(image=im, bboxes=bboxes, class_labels=cls)  # transformed
+                labels["img"] = new["image"]
+                labels["cls"] = np.array(new["class_labels"])
+                bboxes = np.array(new["bboxes"])
+            labels["instances"].update(bboxes=bboxes)  # keep original boxes if transform was skipped
+        return labels
+
+
+# TODO: technically this is not an augmentation, maybe we should put this in another file
+class Format:
+
+    def __init__(self, bbox_format="xywh", normalize=True, mask=False, mask_ratio=4, mask_overlap=True, batch_idx=True):
+        self.bbox_format = bbox_format
+        self.normalize = normalize
+        self.mask = mask  # set False when training detection only
+        self.mask_ratio = mask_ratio
+        self.mask_overlap = mask_overlap
+        self.batch_idx = batch_idx  # keep the batch indexes
+
+    def __call__(self, labels):
+        img = labels["img"]
+        h, w = img.shape[:2]
+        cls = labels.pop("cls")
+        instances = labels.pop("instances")
+        instances.convert_bbox(format=self.bbox_format)
+        instances.denormalize(w, h)
+        nl = len(instances)
+
+        if instances.segments is not None and self.mask:
+            masks, instances, cls = self._format_segments(instances, cls, w, h)
+            labels["masks"] = (torch.from_numpy(masks) if nl else torch.zeros(
+                1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, img.shape[1] // self.mask_ratio))
+        if self.normalize:
+            instances.normalize(w, h)
+        labels["img"] = self._format_img(img)
+        labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl)
+        labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4))
+        if instances.keypoints is not None:
+            labels["keypoints"] = torch.from_numpy(instances.keypoints) if nl else torch.zeros((nl, 17, 2))
+        # then we can use collate_fn
+        if self.batch_idx:
+            labels["batch_idx"] = torch.zeros(nl)
+        return labels
+
+    def _format_img(self, img):
+        if len(img.shape) < 3:
+            img = np.expand_dims(img, -1)
+        img = np.ascontiguousarray(img.transpose(2, 0, 1))
+        img = torch.from_numpy(img)
+        return img
+
+    def _format_segments(self, instances, cls, w, h):
+        """convert polygon points to bitmap"""
+        segments = instances.segments
+        if self.mask_overlap:
+            masks, sorted_idx = polygons2masks_overlap((h, w), segments, downsample_ratio=self.mask_ratio)
+            masks = masks[None]  # (640, 640) -> (1, 640, 640)
+            instances = instances[sorted_idx]
+            cls = cls[sorted_idx]
+        else:
+            masks = polygons2masks((h, w), segments, color=1, downsample_ratio=self.mask_ratio)
+
+        return masks, instances, cls
+
+
+def mosaic_transforms(img_size, hyp):
+    pre_transform = Compose([
+        Mosaic(img_size=img_size, p=hyp.mosaic, border=[-img_size // 2, -img_size // 2]),
+        CopyPaste(p=hyp.copy_paste),
+        RandomPerspective(
+            degrees=hyp.degrees,
+            translate=hyp.translate,
+            scale=hyp.scale,
+            shear=hyp.shear,
+            perspective=hyp.perspective,
+            border=[-img_size // 2, -img_size // 2],
+        ),])
+    transforms = Compose([
+        pre_transform,
+        MixUp(
+            pre_transform=pre_transform,
+            p=hyp.mixup,
+        ),
+        Albumentations(p=1.0),
+        RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
+        RandomFlip(direction="vertical", p=hyp.flipud),
+        RandomFlip(direction="horizontal", p=hyp.fliplr),])
+    return transforms
+
+
+def affine_transforms(img_size, hyp):
+    # rect, randomperspective, albumentation, hsv, flipud, fliplr
+    transforms = Compose([
+        LetterBox(new_shape=(img_size, img_size)),
+        RandomPerspective(
+            degrees=hyp.degrees,
+            translate=hyp.translate,
+            scale=hyp.scale,
+            shear=hyp.shear,
+            perspective=hyp.perspective,
+            border=[0, 0],
+        ),
+        Albumentations(p=1.0),
+        RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
+        RandomFlip(direction="vertical", p=hyp.flipud),
+        RandomFlip(direction="horizontal", p=hyp.fliplr),])
+    return transforms
+
+
+# Classification augmentations -------------------------------------------------------------------------------------------
+def classify_transforms(size=224):
+    # Transforms to apply if albumentations not installed
+    assert isinstance(size, int), f"ERROR: classify_transforms size {size} must be integer, not (list, tuple)"
+    # T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
+    return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
+
+
+def classify_albumentations(
+        augment=True,
+        size=224,
+        scale=(0.08, 1.0),
+        hflip=0.5,
+        vflip=0.0,
+        jitter=0.4,
+        mean=IMAGENET_MEAN,
+        std=IMAGENET_STD,
+        auto_aug=False,
+):
+    # YOLOv5 classification Albumentations (optional, only used if package is installed)
+    prefix = colorstr("albumentations: ")
+    try:
+        import albumentations as A
+        from albumentations.pytorch import ToTensorV2
+
+        check_version(A.__version__, "1.0.3", hard=True)  # version requirement
+        if augment:  # Resize and crop
+            T = [A.RandomResizedCrop(height=size, width=size, scale=scale)]
+            if auto_aug:
+                # TODO: implement AugMix, AutoAug & RandAug in albumentation
+                LOGGER.info(f"{prefix}auto augmentations are currently not supported")
+            else:
+                if hflip > 0:
+                    T += [A.HorizontalFlip(p=hflip)]
+                if vflip > 0:
+                    T += [A.VerticalFlip(p=vflip)]
+                if jitter > 0:
+                    color_jitter = (float(jitter),) * 3  # repeat value for brightness, contrast, saturation, 0 hue
+                    T += [A.ColorJitter(*color_jitter, 0)]
+        else:  # Use fixed crop for eval set (reproducibility)
+            T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
+        T += [A.Normalize(mean=mean, std=std), ToTensorV2()]  # Normalize and convert to Tensor
+        LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
+        return A.Compose(T)
+
+    except ImportError:  # package not installed, skip
+        pass
+    except Exception as e:
+        LOGGER.info(f"{prefix}{e}")
+
+
+class ClassifyLetterBox:
+    # YOLOv5 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
+    def __init__(self, size=(640, 640), auto=False, stride=32):
+        super().__init__()
+        self.h, self.w = (size, size) if isinstance(size, int) else size
+        self.auto = auto  # pass max size integer, automatically solve for short side using stride
+        self.stride = stride  # used with auto
+
+    def __call__(self, im):  # im = np.array HWC
+        imh, imw = im.shape[:2]
+        r = min(self.h / imh, self.w / imw)  # ratio of new/old
+        h, w = round(imh * r), round(imw * r)  # resized image
+        hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else (self.h, self.w)
+        top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)
+        im_out = np.full((hs, ws, 3), 114, dtype=im.dtype)  # pad to the (possibly stride-rounded) output shape
+        im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
+        return im_out
+
+
+class CenterCrop:
+    # YOLOv5 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])
+    def __init__(self, size=640):
+        super().__init__()
+        self.h, self.w = (size, size) if isinstance(size, int) else size
+
+    def __call__(self, im):  # im = np.array HWC
+        imh, imw = im.shape[:2]
+        m = min(imh, imw)  # min dimension
+        top, left = (imh - m) // 2, (imw - m) // 2
+        return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
+
+
+class ToTensor:
+    # YOLOv5 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
+    def __init__(self, half=False):
+        super().__init__()
+        self.half = half
+
+    def __call__(self, im):  # im = np.array HWC in BGR order
+        im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1])  # HWC to CHW -> BGR to RGB -> contiguous
+        im = torch.from_numpy(im)  # to torch
+        im = im.half() if self.half else im.float()  # uint8 to fp16/32
+        im /= 255.0  # 0-255 to 0.0-1.0
+        return im
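+
+
+# Classification preprocessing sketch ("bus.jpg" is only an illustrative path):
+#   im = cv2.imread("bus.jpg")  # HWC BGR uint8
+#   x = T.Compose([CenterCrop(224), ToTensor()])(im)[None]  # (1, 3, 224, 224) RGB float in [0, 1]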
diff --git a/ultralytics/yolo/data/base.py b/ultralytics/yolo/data/base.py
new file mode 100644
index 00000000..ebd3f17c
--- /dev/null
+++ b/ultralytics/yolo/data/base.py
@@ -0,0 +1,224 @@
+import glob
+import os
+from multiprocessing.pool import ThreadPool
+from pathlib import Path
+from typing import Optional
+
+import cv2
+import numpy as np
+from torch.utils.data import Dataset
+from tqdm import tqdm
+
+from ..utils.general import NUM_THREADS
+from .utils import BAR_FORMAT, HELP_URL, IMG_FORMATS, LOCAL_RANK
+
+
+class BaseDataset(Dataset):
+    """Base Dataset.
+    Args:
+        img_path (str): image path.
+        img_size (int): target image size.
+        label_path (str): label path; this can also be an ann_file or other custom label path.
+    """
+
+    def __init__(
+        self,
+        img_path,
+        img_size=640,
+        label_path=None,
+        cache=False,
+        augment=True,
+        hyp=None,
+        prefix="",
+        rect=False,
+        batch_size=None,
+        stride=32,
+        pad=0.5,
+        single_cls=False,
+    ):
+        super().__init__()
+        self.img_path = img_path
+        self.img_size = img_size
+        self.label_path = label_path
+        self.augment = augment
+        self.prefix = prefix
+
+        self.im_files = self.get_img_files(self.img_path)
+        self.labels = self.get_labels()
+        self.single_cls = single_cls  # read by update_labels()
+        if single_cls:
+            self.update_labels(include_class=[])
+
+        self.ni = len(self.im_files)
+
+        # rect stuff
+        self.rect = rect
+        self.batch_size = batch_size
+        self.stride = stride
+        self.pad = pad
+        if self.rect:
+            assert self.batch_size is not None
+            self.set_rectangle()
+
+        # cache stuff
+        self.ims = [None] * self.ni
+        self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
+        self.cache = cache  # cache_images() reads this to choose RAM vs disk caching
+        if cache:
+            self.cache_images()
+
+        # transforms
+        self.transforms = self.build_transforms(hyp=hyp)
+
+    def get_img_files(self, img_path):
+        """Read image files."""
+        try:
+            f = []  # image files
+            for p in img_path if isinstance(img_path, list) else [img_path]:
+                p = Path(p)  # os-agnostic
+                if p.is_dir():  # dir
+                    f += glob.glob(str(p / "**" / "*.*"), recursive=True)
+                    # f = list(p.rglob('*.*'))  # pathlib
+                elif p.is_file():  # file
+                    with open(p) as t:
+                        t = t.read().strip().splitlines()
+                        parent = str(p.parent) + os.sep
+                        f += [x.replace("./", parent) if x.startswith("./") else x for x in t]  # local to global path
+                        # f += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
+                else:
+                    raise FileNotFoundError(f"{self.prefix}{p} does not exist")
+            im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
+            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib
+            assert im_files, f"{self.prefix}No images found"
+        except Exception as e:
+            raise Exception(f"{self.prefix}Error loading data from {img_path}: {e}\n{HELP_URL}")
+        return im_files
+
+    def update_labels(self, include_class: Optional[list]):
+        """include_class, filter labels to include only these classes (optional)"""
+        include_class_array = np.array(include_class).reshape(1, -1)
+        for i in range(len(self.labels)):
+            if include_class:
+                cls = self.labels[i]["cls"]
+                bboxes = self.labels[i]["bboxes"]
+                segments = self.labels[i]["segments"]
+                j = (cls == include_class_array).any(1)
+                self.labels[i]["cls"] = cls[j]
+                self.labels[i]["bboxes"] = bboxes[j]
+                if segments:
+                    self.labels[i]["segments"] = segments[j]
+            if self.single_cls:
+                self.labels[i]["cls"] = 0
+
+    def load_image(self, i):
+        # Loads 1 image from dataset index 'i', returns (im, hw_original, hw_resized)
+        im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
+        if im is None:  # not cached in RAM
+            if fn.exists():  # load npy
+                im = np.load(fn)
+            else:  # read image
+                im = cv2.imread(f)  # BGR
+                assert im is not None, f"Image Not Found {f}"
+            h0, w0 = im.shape[:2]  # orig hw
+            r = self.img_size / max(h0, w0)  # ratio
+            if r != 1:  # if sizes are not equal
+                interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
+                im = cv2.resize(im, (int(w0 * r), int(h0 * r)), interpolation=interp)
+            return im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized
+        return self.ims[i], self.im_hw0[i], self.im_hw[i]  # im, hw_original, hw_resized
+
+    def cache_images(self):
+        # cache images to memory or disk
+        gb = 0  # Gigabytes of cached images
+        self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
+        fcn = self.cache_images_to_disk if self.cache == "disk" else self.load_image
+        results = ThreadPool(NUM_THREADS).imap(fcn, range(self.ni))
+        pbar = tqdm(enumerate(results), total=self.ni, bar_format=BAR_FORMAT, disable=LOCAL_RANK > 0)
+        for i, x in pbar:
+            if self.cache == "disk":
+                gb += self.npy_files[i].stat().st_size
+            else:  # 'ram'
+                self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
+                gb += self.ims[i].nbytes
+            pbar.desc = f"{self.prefix}Caching images ({gb / 1E9:.1f}GB {self.cache})"
+        pbar.close()
+
+    def cache_images_to_disk(self, i):
+        # Saves an image as an *.npy file for faster loading
+        f = self.npy_files[i]
+        if not f.exists():
+            np.save(f.as_posix(), cv2.imread(self.im_files[i]))
+
+    def set_rectangle(self):
+        bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int)  # batch index
+        nb = bi[-1] + 1  # number of batches
+
+        s = np.array([x["shape"] for x in self.labels])  # hw
+        ar = s[:, 0] / s[:, 1]  # aspect ratio
+        irect = ar.argsort()
+        self.im_files = [self.im_files[i] for i in irect]
+        self.labels = [self.labels[i] for i in irect]
+        ar = ar[irect]
+
+        # Set training image shapes
+        shapes = [[1, 1]] * nb
+        for i in range(nb):
+            ari = ar[bi == i]
+            mini, maxi = ari.min(), ari.max()
+            if maxi < 1:
+                shapes[i] = [maxi, 1]
+            elif mini > 1:
+                shapes[i] = [1, 1 / mini]
+
+        self.batch_shapes = np.ceil(np.array(shapes) * self.img_size / self.stride + self.pad).astype(int) * self.stride
+        self.batch = bi  # batch index of image
+
+    def __getitem__(self, index):
+        label = self.get_label_info(index)
+        if self.augment:
+            label["dataset"] = self
+        return self.transforms(label)
+
+    def get_label_info(self, index):
+        label = self.labels[index].copy()
+        img, (h0, w0), (h, w) = self.load_image(index)
+        label["img"] = img
+        label["ori_shape"] = (h0, w0)
+        label["resized_shape"] = (h, w)
+        if self.rect:
+            label["rect_shape"] = self.batch_shapes[self.batch[index]]
+        label = self.update_labels_info(label)
+        return label
+
+    def __len__(self):
+        return len(self.im_files)
+
+    def update_labels_info(self, label):
+        """custom your label format here"""
+        return label
+
+    def build_transforms(self, hyp=None):
+        """Users can custom augmentations here
+        like:
+            if self.augment:
+                # training transforms
+                return Compose([])
+            else:
+                # val transforms
+                return Compose([])
+        """
+        raise NotImplementedError
+
+    def get_labels(self):
+        """Users can custom their own format here.
+        Make sure your output is a list with each element like below:
+            dict(
+                im_file=im_file,
+                shape=shape,  # format: (height, width)
+                cls=cls,
+                bboxes=bboxes, # xywh
+                segments=segments,  # xy
+                keypoints=keypoints, # xy
+                normalized=True, # or False
+                bbox_format="xyxy",  # or xywh, ltwh
+            )
+        """
+        raise NotImplementedError
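+
+
+# A minimal subclass sketch (illustrative names, assuming the transforms from .augment):
+#   class MyDataset(BaseDataset):
+#       def get_labels(self):
+#           ...  # return the list of dicts described above
+#       def build_transforms(self, hyp=None):
+#           return Compose([LetterBox(new_shape=(self.img_size, self.img_size))])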
diff --git a/ultralytics/yolo/data/build.py b/ultralytics/yolo/data/build.py
new file mode 100644
index 00000000..85b20311
--- /dev/null
+++ b/ultralytics/yolo/data/build.py
@@ -0,0 +1,145 @@
+import os
+import random
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader, dataloader, distributed
+
+from ..utils.general import LOGGER
+from ..utils.torch_utils import torch_distributed_zero_first
+from .dataset import ClassificationDataset, YOLODataset
+from .utils import PIN_MEMORY, RANK
+
+
+class InfiniteDataLoader(dataloader.DataLoader):
+    """Dataloader that reuses workers
+
+    Uses same syntax as vanilla DataLoader
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
+        self.iterator = super().__iter__()
+
+    def __len__(self):
+        return len(self.batch_sampler.sampler)
+
+    def __iter__(self):
+        for _ in range(len(self)):
+            yield next(self.iterator)
+
+
+class _RepeatSampler:
+    """Sampler that repeats forever
+
+    Args:
+        sampler (Sampler)
+    """
+
+    def __init__(self, sampler):
+        self.sampler = sampler
+
+    def __iter__(self):
+        while True:
+            yield from iter(self.sampler)
+
+
+def seed_worker(worker_id):
+    # Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader
+    worker_seed = torch.initial_seed() % 2 ** 32
+    np.random.seed(worker_seed)
+    random.seed(worker_seed)
+
+
+# TODO: we can inject most args from a config file
+def build_dataloader(
+    img_path,
+    img_size,
+    batch_size,
+    single_cls=False,
+    hyp=None,
+    augment=False,
+    cache=False,
+    image_weights=False,
+    stride=32,
+    label_path=None,
+    pad=0.0,
+    rect=False,
+    rank=-1,
+    workers=8,
+    prefix="",
+    shuffle=False,
+    use_segments=False,
+    use_keypoints=False,
+):
+    if rect and shuffle:
+        LOGGER.warning("WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False")
+        shuffle = False
+    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
+        dataset = YOLODataset(
+            img_path=img_path,
+            img_size=img_size,
+            batch_size=batch_size,
+            label_path=label_path,
+            augment=augment,  # augmentation
+            hyp=hyp,
+            rect=rect,  # rectangular batches
+            cache=cache,
+            single_cls=single_cls,
+            stride=int(stride),
+            pad=pad,
+            prefix=prefix,
+            use_segments=use_segments,
+            use_keypoints=use_keypoints,
+        )
+
+    batch_size = min(batch_size, len(dataset))
+    nd = torch.cuda.device_count()  # number of CUDA devices
+    nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])  # number of workers
+    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
+    loader = DataLoader if image_weights else InfiniteDataLoader  # only DataLoader allows for attribute updates
+    generator = torch.Generator()
+    generator.manual_seed(6148914691236517205 + RANK)
+    return (
+        loader(
+            dataset=dataset,
+            batch_size=batch_size,
+            shuffle=shuffle and sampler is None,
+            num_workers=nw,
+            sampler=sampler,
+            pin_memory=PIN_MEMORY,
+            collate_fn=getattr(dataset, "collate_fn", None),
+            worker_init_fn=seed_worker,
+            generator=generator,
+        ),
+        dataset,
+    )
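+
+
+# Example usage (a sketch; `hyp` must expose the augmentation hyperparameters that the
+# transforms read, e.g. hyp.mosaic and hyp.fliplr):
+#   loader, dataset = build_dataloader("coco128/images/train2017", img_size=640,
+#                                      batch_size=16, hyp=hyp, augment=True, shuffle=True)
+#   batch = next(iter(loader))  # dict with "img", "cls", "bboxes", "batch_idx", ...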
+
+
+# build classification
+def build_classification_dataloader(path,
+                                    imgsz=224,
+                                    batch_size=16,
+                                    augment=True,
+                                    cache=False,
+                                    rank=-1,
+                                    workers=8,
+                                    shuffle=True):
+    # Returns Dataloader object to be used with YOLOv5 Classifier
+    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
+        dataset = ClassificationDataset(root=path, imgsz=imgsz, augment=augment, cache=cache)
+    batch_size = min(batch_size, len(dataset))
+    nd = torch.cuda.device_count()
+    nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])
+    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
+    generator = torch.Generator()
+    generator.manual_seed(6148914691236517205 + RANK)
+    return InfiniteDataLoader(dataset,
+                              batch_size=batch_size,
+                              shuffle=shuffle and sampler is None,
+                              num_workers=nw,
+                              sampler=sampler,
+                              pin_memory=PIN_MEMORY,
+                              worker_init_fn=seed_worker,
+                              generator=generator)  # or DataLoader(persistent_workers=True)
diff --git a/ultralytics/yolo/dataloaders/__init__.py b/ultralytics/yolo/data/dataloaders/__init__.py
similarity index 100%
rename from ultralytics/yolo/dataloaders/__init__.py
rename to ultralytics/yolo/data/dataloaders/__init__.py
diff --git a/ultralytics/yolo/dataloaders/box.py b/ultralytics/yolo/data/dataloaders/box.py
similarity index 100%
rename from ultralytics/yolo/dataloaders/box.py
rename to ultralytics/yolo/data/dataloaders/box.py
diff --git a/ultralytics/yolo/dataloaders/segment.py b/ultralytics/yolo/data/dataloaders/segment.py
similarity index 100%
rename from ultralytics/yolo/dataloaders/segment.py
rename to ultralytics/yolo/data/dataloaders/segment.py
diff --git a/ultralytics/yolo/data/dataset.py b/ultralytics/yolo/data/dataset.py
new file mode 100644
index 00000000..3405b8e0
--- /dev/null
+++ b/ultralytics/yolo/data/dataset.py
@@ -0,0 +1,213 @@
+from itertools import repeat
+from multiprocessing.pool import Pool
+from pathlib import Path
+
+import cv2
+import numpy as np
+import torch
+import torchvision
+from tqdm import tqdm
+
+from ..utils.general import LOGGER, NUM_THREADS
+from .augment import *
+from .base import BaseDataset
+from .utils import BAR_FORMAT, HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image_label
+
+
+class YOLODataset(BaseDataset):
+    """YOLO Dataset.
+    Args:
+        img_path (str): image path.
+        prefix (str): prefix.
+    """
+    cache_version = 0.6  # dataset labels *.cache version
+    rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
+
+    def __init__(
+        self,
+        img_path,
+        img_size=640,
+        label_path=None,
+        cache=False,
+        augment=True,
+        hyp=None,
+        prefix="",
+        rect=False,
+        batch_size=None,
+        stride=32,
+        pad=0.0,
+        single_cls=False,
+        use_segments=False,
+        use_keypoints=False,
+    ):
+        self.use_segments = use_segments
+        self.use_keypoints = use_keypoints
+        assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
+        super().__init__(img_path, img_size, label_path, cache, augment, hyp, prefix, rect, batch_size, stride, pad,
+                         single_cls)
+
+    def cache_labels(self, path=Path("./labels.cache")):
+        # Cache dataset labels, check images and read shapes
+        x = {"labels": []}
+        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
+        desc = f"{self.prefix}Scanning '{path.parent / path.stem}' images and labels..."
+        with Pool(NUM_THREADS) as pool:
+            pbar = tqdm(
+                pool.imap(verify_image_label,
+                          zip(self.im_files, self.label_files, repeat(self.prefix), repeat(self.use_keypoints))),
+                desc=desc,
+                total=len(self.im_files),
+                bar_format=BAR_FORMAT,
+            )
+            for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
+                nm += nm_f
+                nf += nf_f
+                ne += ne_f
+                nc += nc_f
+                if im_file:
+                    x["labels"].append(
+                        dict(
+                            im_file=im_file,
+                            shape=shape,
+                            cls=lb[:, 0:1],  # n, 1
+                            bboxes=lb[:, 1:],  # n, 4
+                            segments=segments,
+                            keypoints=keypoint,
+                            normalized=True,
+                            bbox_format="xywh",
+                        ))
+                if msg:
+                    msgs.append(msg)
+                pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupt"
+
+        pbar.close()
+        if msgs:
+            LOGGER.info("\n".join(msgs))
+        if nf == 0:
+            LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}")
+        x["hash"] = get_hash(self.label_files + self.im_files)
+        x["results"] = nf, nm, ne, nc, len(self.im_files)
+        x["msgs"] = msgs  # warnings
+        x["version"] = self.cache_version  # cache version
+        try:
+            np.save(path, x)  # save cache for next time
+            path.with_suffix(".cache.npy").rename(path)  # remove .npy suffix
+            LOGGER.info(f"{self.prefix}New cache created: {path}")
+        except Exception as e:
+            LOGGER.warning(
+                f"{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable: {e}")  # not writeable
+        return x
+
+    def get_labels(self):
+        self.label_files = img2label_paths(self.im_files)
+        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
+        try:
+            cache, exists = np.load(cache_path, allow_pickle=True).item(), True  # load dict
+            assert cache["version"] == self.cache_version  # matches current version
+            assert cache["hash"] == get_hash(self.label_files + self.im_files)  # identical hash
+        except Exception:
+            cache, exists = self.cache_labels(cache_path), False  # run cache ops
+
+        # Display cache
+        nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total
+        if exists and LOCAL_RANK in {-1, 0}:
+            d = f"Scanning '{cache_path}' images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupt"
+            tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=BAR_FORMAT)  # display cache results
+            if cache["msgs"]:
+                LOGGER.info("\n".join(cache["msgs"]))  # display warnings
+        assert nf > 0, f"{self.prefix}No labels found in {cache_path}, can not start training. {HELP_URL}"
+
+        # Read cache
+        [cache.pop(k) for k in ("hash", "version", "msgs")]  # remove items
+        labels = cache["labels"]
+        nl = len(np.concatenate([label["cls"] for label in labels], 0))  # number of labels
+        assert nl > 0, f"{self.prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}"
+        return labels
+
+    # TODO: use hyp config to set all these augmentations
+    def build_transforms(self, hyp=None):
+        mosaic = self.augment and not self.rect
+        # mosaic = False
+        if self.augment:
+            if mosaic:
+                transforms = mosaic_transforms(self.img_size, hyp)
+            else:
+                transforms = affine_transforms(self.img_size, hyp)
+        else:
+            transforms = Compose([LetterBox(new_shape=(self.img_size, self.img_size))])
+        transforms.append(Format(bbox_format="xywh", normalize=True, mask=self.use_segments, batch_idx=True))
+        return transforms
+
+    def update_labels_info(self, label):
+        """custom your label format here"""
+        # NOTE: cls is not with bboxes now, since other tasks like classification and semantic segmentation need a independent cls label
+        # we can make it also support classification and semantic segmentation by add or remove some dict keys there.
+        bboxes = label.pop("bboxes")
+        segments = label.pop("segments", None)
+        keypoints = label.pop("keypoints", None)
+        bbox_format = label.pop("bbox_format")
+        normalized = label.pop("normalized")
+        label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
+        return label
+
+    @staticmethod
+    def collate_fn(batch):
+        # TODO: returning a dict can make things easier and cleaner when using the dataset
+        # in training, though it may be slightly slower.
+        new_batch = {}
+        keys = batch[0].keys()
+        values = list(zip(*[list(b.values()) for b in batch]))
+        for i, k in enumerate(keys):
+            value = values[i]
+            if k == "img":
+                value = torch.stack(value, 0)
+            if k in ["masks", "keypoints", "bboxes", "cls"]:  # Format() emits these keys
+                value = torch.cat(value, 0)
+            new_batch[k] = value
+        new_batch["batch_idx"] = list(new_batch["batch_idx"])
+        for i in range(len(new_batch["batch_idx"])):
+            new_batch["batch_idx"][i] += i  # add target image index for build_targets()
+        new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
+        return new_batch
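+
+    # e.g. for a batch of two images with 3 and 2 boxes respectively, batch_idx becomes
+    # tensor([0., 0., 0., 1., 1.]), mapping every box row back to its source image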
+
+
+# Classification dataloaders -------------------------------------------------------------------------------------------
+class ClassificationDataset(torchvision.datasets.ImageFolder):
+    """
+    YOLOv5 Classification Dataset.
+    Arguments
+        root:  Dataset path
+        transform:  torchvision transforms, used by default
+        album_transform: Albumentations transforms, used if installed
+    """
+
+    def __init__(self, root, augment, imgsz, cache=False):
+        super().__init__(root=root)
+        self.torch_transforms = classify_transforms(imgsz)
+        self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
+        self.cache_ram = cache is True or cache == "ram"
+        self.cache_disk = cache == "disk"
+        self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples]  # file, index, npy, im
+
+    def __getitem__(self, i):
+        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
+        if self.cache_ram and im is None:
+            im = self.samples[i][3] = cv2.imread(f)
+        elif self.cache_disk:
+            if not fn.exists():  # save npy if it doesn't exist yet
+                np.save(fn.as_posix(), cv2.imread(f))
+            im = np.load(fn)
+        else:  # read image
+            im = cv2.imread(f)  # BGR
+        if self.album_transforms:
+            sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
+        else:
+            sample = self.torch_transforms(im)
+        return sample, j
+
+
+# TODO: support semantic segmentation
+class SemanticDataset(BaseDataset):
+
+    def __init__(self):
+        pass
diff --git a/ultralytics/yolo/data/dataset_wrappers.py b/ultralytics/yolo/data/dataset_wrappers.py
new file mode 100644
index 00000000..77b8f99a
--- /dev/null
+++ b/ultralytics/yolo/data/dataset_wrappers.py
@@ -0,0 +1,37 @@
+import collections
+from copy import deepcopy
+
+from .augment import LetterBox
+
+
+class MixAndRectDataset:
+    """A wrapper of multiple images mixed dataset.
+
+    Args:
+        dataset (:obj:`BaseDataset`): The dataset to be mixed.
+        transforms (Sequence[dict]): config dict to be composed.
+    """
+
+    def __init__(self, dataset):
+        self.dataset = dataset
+        self.img_size = dataset.img_size
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, index):
+        labels = deepcopy(self.dataset[index])
+        for transform in self.dataset.transforms.tolist():
+            # mosaic and mixup
+            if hasattr(transform, "get_indexes"):
+                indexes = transform.get_indexes(self.dataset)
+                if not isinstance(indexes, collections.abc.Sequence):
+                    indexes = [indexes]
+                mix_labels = [deepcopy(self.dataset[index]) for index in indexes]
+                labels["mix_labels"] = mix_labels
+            if self.dataset.rect and isinstance(transform, LetterBox):
+                transform.new_shape = self.dataset.batch_shapes[self.dataset.batch[index]]
+            labels = transform(labels)
+            if "mix_labels" in labels:
+                labels.pop("mix_labels")
+        return labels
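+
+
+# Editor's sketch of typical usage, assuming `dataset` is a built YOLODataset whose
+# transforms include mosaic/mixup transforms exposing `get_indexes`:
+#   wrapped = MixAndRectDataset(dataset)
+#   labels = wrapped[0]  # mix partners are fetched here and dropped after the transform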
diff --git a/ultralytics/yolo/data/utils.py b/ultralytics/yolo/data/utils.py
new file mode 100644
index 00000000..ef3bb9b2
--- /dev/null
+++ b/ultralytics/yolo/data/utils.py
@@ -0,0 +1,177 @@
+import contextlib
+import hashlib
+import os
+
+import cv2
+import numpy as np
+from PIL import ExifTags, Image, ImageOps
+
+from ..utils.general import segments2boxes
+
+HELP_URL = "See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data"
+IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"  # include image suffixes
+VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv"  # include video suffixes
+BAR_FORMAT = "{l_bar}{bar:10}{r_bar}{bar:-10b}"  # tqdm bar format
+LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1))  # https://pytorch.org/docs/stable/elastic/run.html
+RANK = int(os.getenv('RANK', -1))
+PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true"  # global pin_memory for dataloaders
+IMAGENET_MEAN = 0.485, 0.456, 0.406  # RGB mean
+IMAGENET_STD = 0.229, 0.224, 0.225  # RGB standard deviation
+
+# Get orientation exif tag
+for orientation in ExifTags.TAGS.keys():
+    if ExifTags.TAGS[orientation] == "Orientation":
+        break
+
+
+def img2label_paths(img_paths):
+    # Define label paths as a function of image paths
+    sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}"  # /images/, /labels/ substrings
+    return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]
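+
+
+# Editor's sketch: only the last '/images/' path component is swapped (POSIX paths shown):
+#   img2label_paths(['ds/images/train/0001.jpg'])  # -> ['ds/labels/train/0001.txt']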
+
+
+def get_hash(paths):
+    # Returns a single hash value of a list of paths (files or dirs)
+    size = sum(os.path.getsize(p) for p in paths if os.path.exists(p))  # sizes
+    h = hashlib.md5(str(size).encode())  # hash sizes
+    h.update("".join(paths).encode())  # hash paths
+    return h.hexdigest()  # return hash
+
+
+def exif_size(img):
+    # Returns exif-corrected PIL size
+    s = img.size  # (width, height)
+    with contextlib.suppress(Exception):
+        rotation = dict(img._getexif().items())[orientation]
+        if rotation in [6, 8]:  # rotation 270 or 90
+            s = (s[1], s[0])
+    return s
+
+
+def verify_image_label(args):
+    # Verify one image-label pair
+    im_file, lb_file, prefix, keypoint = args
+    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", None, None  # number (missing, found, empty, corrupt), message, segments, keypoints
+    try:
+        # verify images
+        im = Image.open(im_file)
+        im.verify()  # PIL verify
+        shape = exif_size(im)  # image size
+        shape = (shape[1], shape[0])  # hw
+        assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
+        assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
+        if im.format.lower() in ("jpg", "jpeg"):
+            with open(im_file, "rb") as f:
+                f.seek(-2, 2)
+                if f.read() != b"\xff\xd9":  # corrupt JPEG
+                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
+                    msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved"
+
+        # verify labels
+        if os.path.isfile(lb_file):
+            nf = 1  # label found
+            with open(lb_file) as f:
+                lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
+                if any(len(x) > 6 for x in lb) and (not keypoint):  # is segment
+                    classes = np.array([x[0] for x in lb], dtype=np.float32)
+                    segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb]  # (cls, xy1...)
+                    lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
+                lb = np.array(lb, dtype=np.float32)
+            nl = len(lb)
+            if nl:
+                if keypoint:
+                    assert lb.shape[1] == 56, "labels require 56 columns each"
+                    assert (lb[:, 5::3] <= 1).all(), "non-normalized or out of bounds coordinate labels"
+                    assert (lb[:, 6::3] <= 1).all(), "non-normalized or out of bounds coordinate labels"
+                    kpts = np.zeros((lb.shape[0], 39))
+                    for i in range(len(lb)):
+                        kpt = np.delete(lb[i, 5:], np.arange(2, lb.shape[1] - 5,
+                                                             3))  # remove the occlusion parameter from the GT
+                        kpts[i] = np.hstack((lb[i, :5], kpt))
+                    lb = kpts
+                    assert lb.shape[1] == 39, "labels require 39 columns each after removing occlusion parameter"
+                else:
+                    assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
+                    assert (lb >= 0).all(), f"negative label values {lb[lb < 0]}"
+                    assert (lb[:, 1:] <=
+                            1).all(), f"non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}"
+                _, i = np.unique(lb, axis=0, return_index=True)
+                if len(i) < nl:  # duplicate row check
+                    lb = lb[i]  # remove duplicates
+                    if segments:
+                        segments = [segments[x] for x in i]
+                    msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed"
+            else:
+                ne = 1  # label empty
+                lb = np.zeros((0, 39), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
+        else:
+            nm = 1  # label missing
+            lb = np.zeros((0, 39), dtype=np.float32) if keypoint else np.zeros((0, 5), dtype=np.float32)
+        if keypoint:
+            keypoints = lb[:, 5:].reshape(-1, 17, 2)
+        lb = lb[:, :5]
+        return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
+    except Exception as e:
+        nc = 1
+        msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}"
+        return [None, None, None, None, None, nm, nf, ne, nc, msg]
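+
+
+# Editor's note — expected label file rows (values normalized to [0, 1]); the
+# numbers below are illustrative only:
+#   detection:    "11 0.339 0.606 0.467 0.312"            (cls x y w h)
+#   segmentation: "11 0.29 0.56 0.31 0.58 0.27 0.60 ..."  (cls x1 y1 x2 y2 ...)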
+
+
+def polygon2mask(img_size, polygons, color=1, downsample_ratio=1):
+    """
+    Args:
+        img_size (tuple): The image size.
+        polygons (np.ndarray): [N, M], N is the number of polygons,
+            M is the number of coordinates (divisible by 2).
+    """
+    mask = np.zeros(img_size, dtype=np.uint8)
+    polygons = np.asarray(polygons)
+    polygons = polygons.astype(np.int32)
+    shape = polygons.shape
+    polygons = polygons.reshape(shape[0], -1, 2)
+    cv2.fillPoly(mask, polygons, color=color)
+    nh, nw = (img_size[0] // downsample_ratio, img_size[1] // downsample_ratio)
+    # NOTE: fillPoly first, then resize, to keep the loss calculation consistent
+    # with the mask-ratio=1 case.
+    mask = cv2.resize(mask, (nw, nh))
+    return mask
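+
+
+# Editor's sketch: rasterize one triangle into a 4x-downsampled mask.
+#   poly = np.array([[0, 0, 640, 0, 320, 640]], dtype=np.float32)  # x1 y1 x2 y2 x3 y3
+#   m = polygon2mask((640, 640), poly, color=1, downsample_ratio=4)
+#   m.shape  # -> (160, 160)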
+
+
+def polygons2masks(img_size, polygons, color, downsample_ratio=1):
+    """
+    Args:
+        img_size (tuple): The image size.
+        polygons (list[np.ndarray]): each polygon is [N, M],
+            N is the number of polygons,
+            M is the number of coordinates (divisible by 2).
+    """
+    masks = []
+    for si in range(len(polygons)):
+        mask = polygon2mask(img_size, [polygons[si].reshape(-1)], color, downsample_ratio)
+        masks.append(mask)
+    return np.array(masks)
+
+
+def polygons2masks_overlap(img_size, segments, downsample_ratio=1):
+    """Return a (640, 640) overlap mask."""
+    masks = np.zeros((img_size[0] // downsample_ratio, img_size[1] // downsample_ratio),
+                     dtype=np.int32 if len(segments) > 255 else np.uint8)
+    areas = []
+    ms = []
+    for si in range(len(segments)):
+        mask = polygon2mask(
+            img_size,
+            [segments[si].reshape(-1)],
+            downsample_ratio=downsample_ratio,
+            color=1,
+        )
+        ms.append(mask)
+        areas.append(mask.sum())
+    areas = np.asarray(areas)
+    index = np.argsort(-areas)
+    ms = np.array(ms)[index]
+    for i in range(len(segments)):
+        mask = ms[i] * (i + 1)
+        masks = masks + mask
+        masks = np.clip(masks, a_min=0, a_max=i + 1)
+    return masks, index
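+
+
+# Editor's sketch of the overlap encoding: each pixel stores the 1-based index of
+# the instance covering it (sorted by descending area), 0 = background. Because
+# smaller instances are drawn last, they win on overlapping pixels:
+#   masks, sorted_idx = polygons2masks_overlap((640, 640), segments)
+#   masks.max()  # -> len(segments), assuming every instance covers some pixel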
diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py
index e69de29b..17b429f7 100644
--- a/ultralytics/yolo/engine/trainer.py
+++ b/ultralytics/yolo/engine/trainer.py
@@ -0,0 +1,325 @@
+"""
+Simple training loop; boilerplate that could apply to any arbitrary neural network.
+"""
+
+import os
+import time
+from collections import defaultdict
+from datetime import datetime
+from pathlib import Path
+from typing import Union
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn as nn
+from omegaconf import DictConfig, OmegaConf
+from torch.cuda import amp
+from torch.nn.parallel import DistributedDataParallel as DDP
+from tqdm import tqdm
+
+import ultralytics.yolo.utils as utils
+import ultralytics.yolo.utils.loggers as loggers
+from ultralytics.yolo.utils.general import LOGGER, ROOT
+
+CONFIG_PATH_ABS = ROOT / "yolo/utils/configs"
+DEFAULT_CONFIG = "defaults.yaml"
+
+
+class BaseTrainer:
+
+    def __init__(
+            self,
+            model: str,
+            data: str,
+            criterion,  # Should we create our own base loss classes? yolo.losses -> v8.losses.clfLoss
+            validator=None,
+            config=CONFIG_PATH_ABS / DEFAULT_CONFIG):
+        self.console = LOGGER
+        self.model = model
+        self.data = data
+        self.criterion = criterion  # ComputeLoss object TODO: create yolo.Loss classes
+        self.validator = validator or val  # fall back to the dummy validator if none is given
+        self.callbacks = defaultdict(list)
+        self.train, self.hyps = self._get_config(config)
+        self.console.info(f"Training config: \n train: \n {self.train} \n hyps: \n {self.hyps}")  # to debug
+        # Directories
+        self.save_dir = utils.increment_path(Path(self.train.project) / self.train.name, exist_ok=self.train.exist_ok)
+        self.wdir = self.save_dir / 'weights'
+        self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
+        self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'
+
+        # Save run settings
+        utils.save_yaml(self.save_dir / 'train.yaml', OmegaConf.to_container(self.train, resolve=True))
+
+        # device
+        self.device = utils.select_device(self.train.device, self.train.batch_size)
+        self.console.info(f"running on device {self.device}")
+        self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')
+
+        # Model and Dataloaders. TBD: Should we move this inside trainer?
+        self.trainset, self.testset = self.get_dataset()  # initialize dataset before as nc is needed for model
+        self.model = self.get_model()
+        self.model = self.model.to(self.device)
+
+        # epoch level metrics
+        self.metrics = {}  # handle metrics returned by validator
+        self.best_fitness = None
+        self.fitness = None
+        self.loss = None
+
+        for callback, func in loggers.default_callbacks.items():
+            self.add_callback(callback, func)
+
+    def _get_config(self, config: Union[str, Path, DictConfig] = None):
+        """
+        Accepts yaml file name or DictConfig containing experiment configuration.
+        Returns train and hyps namespace
+        :param config: Optional file name or DictConfig object
+        """
+        try:
+            if isinstance(config, (str, Path)):
+                config = OmegaConf.load(config)
+            return config.train, config.hyps
+        except KeyError as e:
+            raise Exception("Missing key(s) in config") from e
+
+    def add_callback(self, onevent: str, callback):
+        """
+        appends the given callback
+        """
+        self.callbacks[onevent].append(callback)
+
+    def set_callback(self, onevent: str, callback):
+        """
+        overrides the existing callbacks with the given callback
+        """
+        self.callbacks[onevent] = [callback]
+
+    def trigger_callbacks(self, onevent: str):
+        for callback in self.callbacks.get(onevent, []):
+            callback(self)
+
+    def run(self):
+        world_size = torch.cuda.device_count()
+        if world_size > 1:
+            mp.spawn(self._do_train, args=(world_size,), nprocs=world_size, join=True)
+        else:
+            self._do_train(-1, 1)
+
+    def _setup_ddp(self, rank, world_size):
+        os.environ['MASTER_ADDR'] = 'localhost'
+        os.environ['MASTER_PORT'] = '9020'
+        torch.cuda.set_device(rank)
+        self.device = torch.device('cuda', rank)
+        print(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device} ")
+
+        dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
+        self.model = self.model.to(self.device)
+        self.model = DDP(self.model, device_ids=[rank])
+        self.train.batch_size = self.train.batch_size // world_size
+
+    def _setup_train(self, rank):
+        """
+        Builds dataloaders and optimizer on correct rank process
+        """
+        self.optimizer = build_optimizer(model=self.model,
+                                         name=self.train.optimizer,
+                                         lr=self.hyps.lr0,
+                                         momentum=self.hyps.momentum,
+                                         decay=self.hyps.weight_decay)
+        self.train_loader = self.get_dataloader(self.trainset, batch_size=self.train.batch_size, rank=rank)
+        if rank in {0, -1}:
+            print(" Creating testloader rank :", rank)
+            # self.test_loader = self.get_dataloader(self.testset,
+            #                                       batch_size=self.train.batch_size*2,
+            #                                       rank=rank)
+            # print("created testloader :", rank)
+
+    def _do_train(self, rank, world_size):
+        if world_size > 1:
+            self._setup_ddp(rank, world_size)
+
+        # callback hook. before_train
+        self._setup_train(rank)
+
+        self.epoch = 1
+        self.epoch_time = None
+        self.epoch_time_start = time.time()
+        self.train_time_start = time.time()
+        for epoch in range(self.train.epochs):
+            # callback hook. on_epoch_start
+            self.model.train()
+            pbar = enumerate(self.train_loader)
+            if rank in {-1, 0}:
+                pbar = tqdm(enumerate(self.train_loader),
+                            total=len(self.train_loader),
+                            bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
+            tloss = 0
+            for i, (images, labels) in pbar:
+                # callback hook. on_batch_start
+                # forward
+                images, labels = self.preprocess_batch(images, labels)
+                self.loss = self.criterion(self.model(images), labels)
+                tloss = (tloss * i + self.loss.item()) / (i + 1)
+
+                # backward
+                self.model.zero_grad(set_to_none=True)
+                self.scaler.scale(self.loss).backward()
+
+                # optimize
+                self.optimizer_step()
+                self.trigger_callbacks('on_batch_end')
+
+                # log
+                mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
+                if rank in {-1, 0}:
+                    pbar.desc = f"{f'{epoch + 1}/{self.train.epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36
+
+            if rank in [-1, 0]:
+                # validation
+                # callback: on_val_start()
+                self.validate()
+                # callback: on_val_end()
+
+                # save model
+                if (not self.train.nosave) or (self.epoch + 1 == self.train.epochs):
+                    self.save_model()
+                    # callback; on_model_save
+
+            self.epoch += 1
+            tnow = time.time()
+            self.epoch_time = tnow - self.epoch_time_start
+            self.epoch_time_start = tnow
+
+            # TODO: termination condition
+
+        self.log(f"\nTraining complete ({(time.time() - self.train_time_start) / 3600:.3f} hours) \
+                            \n{self.usage_help()}")
+        # callback; on_train_end
+        if world_size > 1:
+            dist.destroy_process_group()
+
+    def save_model(self):
+        ckpt = {
+            'epoch': self.epoch,
+            'best_fitness': self.best_fitness,
+            'model': None,  # deepcopy(ema.ema).half(),  # deepcopy(de_parallel(model)).half(),
+            'ema': None,  # deepcopy(ema.ema).half(),
+            'updates': None,  # ema.updates,
+            'optimizer': None,  # optimizer.state_dict(),
+            'train_args': self.train,
+            'date': datetime.now().isoformat()}
+
+        # Save last, best and delete
+        torch.save(ckpt, self.last)
+        if self.best_fitness == self.fitness:
+            torch.save(ckpt, self.best)
+        del ckpt
+
+    def get_dataloader(self, dataset_path, batch_size=16, rank=0):
+        """
+        Returns a dataloader derived from torch.utils.data.DataLoader
+        """
+        pass
+
+    def get_dataset(self):
+        """
+        Uses self.data to download the dataset if needed and verify it.
+        Returns train and val split datasets
+        """
+        pass
+
+    def get_model(self):
+        """
+        Uses self.model to load/create/download the model for any task
+        """
+        pass
+
+    def set_criterion(self, criterion):
+        """
+        :param criterion: yolo.Loss object.
+        """
+        self.criterion = criterion
+
+    def optimizer_step(self):
+        self.scaler.unscale_(self.optimizer)  # unscale gradients
+        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)  # clip gradients
+        self.scaler.step(self.optimizer)
+        self.scaler.update()
+        self.optimizer.zero_grad()
+
+    def preprocess_batch(self, images, labels):
+        """
+        Allows custom preprocessing of model inputs and ground truths, depending on task type
+        """
+        return images.to(self.device, non_blocking=True), labels.to(self.device)
+
+    def validate(self):
+        """
+        Runs validation on test set using self.validator.
+        # TODO: discuss validator class. Enforce that a validator metrics dict
+        # should contain a "fitness" metric.
+        """
+        self.metrics = self.validator(self)
+        self.fitness = self.metrics.get("fitness") or (-self.loss)  # use loss as fitness measure if not found
+        if not self.best_fitness or self.best_fitness < self.fitness:
+            self.best_fitness = self.fitness
+
+    def progress_string(self):
+        """
+        Returns progress string depending on task type.
+        """
+        pass
+
+    def usage_help(self):
+        """
+        Returns a usage message that gets printed to the console after training.
+        """
+        pass
+
+    def log(self, text, rank=-1):
+        """
+        Logs the given text from the master process only (rank -1 or 0).
+        :param text: text to log
+        :param rank: process rank
+        """
+        if rank in {-1, 0}:
+            self.console.info(text)
+
+
+def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
+    # TODO: 1. docstring with example? 2. Move this inside Trainer? or utils?
+    # YOLOv5 3-param group optimizer: 0) weights with decay, 1) weights no decay, 2) biases no decay
+    g = [], [], []  # optimizer parameter groups
+    bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
+    for v in model.modules():
+        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias (no decay)
+            g[2].append(v.bias)
+        if isinstance(v, bn):  # weight (no decay)
+            g[1].append(v.weight)
+        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
+            g[0].append(v.weight)
+
+    if name == 'Adam':
+        optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999))  # adjust beta1 to momentum
+    elif name == 'AdamW':
+        optimizer = torch.optim.AdamW(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
+    elif name == 'RMSProp':
+        optimizer = torch.optim.RMSprop(g[2], lr=lr, momentum=momentum)
+    elif name == 'SGD':
+        optimizer = torch.optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
+    else:
+        raise NotImplementedError(f'Optimizer {name} not implemented.')
+
+    optimizer.add_param_group({'params': g[0], 'weight_decay': decay})  # add g0 with weight_decay
+    optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
+    LOGGER.info(f"optimizer: {type(optimizer).__name__}(lr={lr}) with parameter groups "
+                f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias")
+    return optimizer
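+
+
+# Editor's sketch (values taken from defaults.yaml):
+#   optimizer = build_optimizer(model, name='Adam', lr=0.001, momentum=0.937, decay=0.0005)
+# Only conv/linear weights receive weight decay; BatchNorm weights and all biases
+# stay undecayed, the usual YOLOv5 convention.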
+
+
+# Dummy validator
+def val(trainer: BaseTrainer):
+    trainer.console.info("validating")
+    return {"metric_1": 0.1, "metric_2": 0.2, "fitness": 1}
diff --git a/ultralytics/yolo/utils/__init__.py b/ultralytics/yolo/utils/__init__.py
index e69de29b..31912d96 100644
--- a/ultralytics/yolo/utils/__init__.py
+++ b/ultralytics/yolo/utils/__init__.py
@@ -0,0 +1,17 @@
+from .general import WorkingDirectory, check_version, download, increment_path, save_yaml
+from .torch_utils import LOCAL_RANK, RANK, WORLD_SIZE, DDP_model, select_device, torch_distributed_zero_first
+
+__all__ = [
+    # general
+    "increment_path",
+    "save_yaml",
+    "WorkingDirectory",
+    "download",
+    "check_version",
+    # torch
+    "torch_distributed_zero_first",
+    "LOCAL_RANK",
+    "RANK",
+    "WORLD_SIZE",
+    "DDP_model",
+    "select_device"]
diff --git a/ultralytics/yolo/utils/configs/__init__.py b/ultralytics/yolo/utils/configs/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/ultralytics/yolo/utils/configs/defaults.yaml b/ultralytics/yolo/utils/configs/defaults.yaml
new file mode 100644
index 00000000..ae715972
--- /dev/null
+++ b/ultralytics/yolo/utils/configs/defaults.yaml
@@ -0,0 +1,53 @@
+train:
+  epochs: 300
+  batch_size: 16
+  img_size: 640
+  nosave: False
+  cache: False # True or 'ram' to cache images in RAM, or 'disk'
+  device: '' # cuda device, i.e. 0 or 0,1,2,3 or cpu
+  workers: 8
+  project: "ultralytics-yolo"
+  name: "exp" # TODO: make this informative, maybe exp{#number}_{datetime} ?
+  exist_ok: False
+  pretrained: False
+  optimizer: "Adam" # choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
+  verbose: False
+  seed: 0
+  local_rank: -1
+
+hyps:
+  lr0: 0.001  # initial learning rate (SGD=1E-2, Adam=1E-3)
+  lrf: 0.01  # final OneCycleLR learning rate (lr0 * lrf)
+  momentum: 0.937  # SGD momentum/Adam beta1
+  weight_decay: 0.0005  # optimizer weight decay 5e-4
+  warmup_epochs: 3.0  # warmup epochs (fractions ok)
+  warmup_momentum: 0.8  # warmup initial momentum
+  warmup_bias_lr: 0.1  # warmup initial bias lr
+  box: 0.05  # box loss gain
+  cls: 0.5  # cls loss gain
+  cls_pw: 1.0  # cls BCELoss positive_weight
+  obj: 1.0  # obj loss gain (scale with pixels)
+  obj_pw: 1.0  # obj BCELoss positive_weight
+  iou_t: 0.20  # IoU training threshold
+  anchor_t: 4.0  # anchor-multiple threshold
+  # anchors: 3  # anchors per output layer (0 to ignore)
+  fl_gamma: 0.0  # focal loss gamma (efficientDet default gamma=1.5)
+  hsv_h: 0.015  # image HSV-Hue augmentation (fraction)
+  hsv_s: 0.7  # image HSV-Saturation augmentation (fraction)
+  hsv_v: 0.4  # image HSV-Value augmentation (fraction)
+  degrees: 0.0  # image rotation (+/- deg)
+  translate: 0.1  # image translation (+/- fraction)
+  scale: 0.5  # image scale (+/- gain)
+  shear: 0.0  # image shear (+/- deg)
+  perspective: 0.0  # image perspective (+/- fraction), range 0-0.001
+  flipud: 0.0  # image flip up-down (probability)
+  fliplr: 0.5  # image flip left-right (probability)
+  mosaic: 1.0  # image mosaic (probability)
+  mixup: 0.0  # image mixup (probability)
+  copy_paste: 0.0  # segment copy-paste (probability)
+
+# to disable hydra directory creation
+hydra:
+  output_subdir: null
+  run:
+    dir: .
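+
+# Editor's note: any key above can be overridden from the CLI through hydra, e.g.
+# (illustrative command, assuming a hydra-decorated entrypoint):
+#   python train.py train.epochs=100 train.batch_size=32 hyps.lr0=0.01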
diff --git a/ultralytics/yolo/utils/general.py b/ultralytics/yolo/utils/general.py
new file mode 100644
index 00000000..be5ba78b
--- /dev/null
+++ b/ultralytics/yolo/utils/general.py
@@ -0,0 +1,353 @@
+# TODO: Follow google docs format for all functions. Easier for automatic doc parser
+
+import contextlib
+import logging
+import os
+import platform
+import subprocess
+import urllib
+from itertools import repeat
+from multiprocessing.pool import ThreadPool
+from pathlib import Path
+from zipfile import ZipFile
+
+import numpy as np
+import pkg_resources as pkg
+import requests
+import torch
+import yaml
+
+FILE = Path(__file__).resolve()
+ROOT = FILE.parents[2]  # YOLOv5 root directory
+RANK = int(os.getenv('RANK', -1))
+
+# Settings
+DATASETS_DIR = ROOT.parent / 'datasets'  # YOLOv5 datasets directory
+NUM_THREADS = min(8, max(1, os.cpu_count() - 1))  # number of YOLOv5 multiprocessing threads
+AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true'  # global auto-install mode
+VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true'  # global verbose mode
+FONT = 'Arial.ttf'  # https://ultralytics.com/assets/Arial.ttf
+
+
+def is_colab():
+    # Is environment a Google Colab instance?
+    return "COLAB_GPU" in os.environ
+
+
+def is_kaggle():
+    # Is environment a Kaggle Notebook?
+    return os.environ.get("PWD") == "/kaggle/working" and os.environ.get("KAGGLE_URL_BASE") == "https://www.kaggle.com"
+
+
+def emojis(str=""):
+    # Return platform-dependent emoji-safe version of string
+    return str.encode().decode("ascii", "ignore") if platform.system() == "Windows" else str
+
+
+def set_logging(name=None, verbose=VERBOSE):
+    # Sets level and returns logger
+    if is_kaggle() or is_colab():
+        for h in logging.root.handlers:
+            logging.root.removeHandler(h)  # remove all handlers associated with the root logger object
+    rank = int(os.getenv("RANK", -1))  # rank in world for Multi-GPU trainings
+    level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
+    log = logging.getLogger(name)
+    log.setLevel(level)
+    handler = logging.StreamHandler()
+    handler.setFormatter(logging.Formatter("%(message)s"))
+    handler.setLevel(level)
+    log.addHandler(handler)
+
+
+set_logging()  # run before defining LOGGER
+LOGGER = logging.getLogger("yolov5")  # define globally (used in train.py, val.py, detect.py, etc.)
+if platform.system() == "Windows":
+    for fn in LOGGER.info, LOGGER.warning:
+        setattr(LOGGER, fn.__name__, lambda x, fn=fn: fn(emojis(x)))  # emoji-safe logging; bind fn per iteration
+
+
+def segment2box(segment, width=640, height=640):
+    # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
+    x, y = segment.T  # segment xy
+    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
+    x, y = x[inside], y[inside]
+    return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros(4)  # xyxy
+
+
+def check_version(current="0.0.0", minimum="0.0.0", name="version ", pinned=False, hard=False, verbose=False):
+    # Check version vs. required version
+    current, minimum = (pkg.parse_version(x) for x in (current, minimum))
+    result = (current == minimum) if pinned else (current >= minimum)  # bool
+    s = f"WARNING ⚠️ {name}{minimum} is required by YOLOv5, but {name}{current} is currently installed"  # string
+    if hard:
+        assert result, emojis(s)  # assert min requirements met
+    if verbose and not result:
+        LOGGER.warning(s)
+    return result
+
+
+def colorstr(*input):
+    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e.  colorstr('blue', 'hello world')
+    *args, string = input if len(input) > 1 else ("blue", "bold", input[0])  # color arguments, string
+    colors = {
+        "black": "\033[30m",  # basic colors
+        "red": "\033[31m",
+        "green": "\033[32m",
+        "yellow": "\033[33m",
+        "blue": "\033[34m",
+        "magenta": "\033[35m",
+        "cyan": "\033[36m",
+        "white": "\033[37m",
+        "bright_black": "\033[90m",  # bright colors
+        "bright_red": "\033[91m",
+        "bright_green": "\033[92m",
+        "bright_yellow": "\033[93m",
+        "bright_blue": "\033[94m",
+        "bright_magenta": "\033[95m",
+        "bright_cyan": "\033[96m",
+        "bright_white": "\033[97m",
+        "end": "\033[0m",  # misc
+        "bold": "\033[1m",
+        "underline": "\033[4m",}
+    return "".join(colors[x] for x in args) + f"{string}" + colors["end"]
+
+
+def xyxy2xywh(x):
+    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
+    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
+    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
+    y[:, 2] = x[:, 2] - x[:, 0]  # width
+    y[:, 3] = x[:, 3] - x[:, 1]  # height
+    return y
+
+
+def xywh2xyxy(x):
+    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
+    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
+    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
+    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
+    return y
+
+
+def xywh2ltwh(x):
+    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, w, h] where xy1=top-left
+    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
+    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
+    return y
+
+
+def xyxy2ltwh(x):
+    # Convert nx4 boxes from [x1, y1, x2, y2] to [x1, y1, w, h] where xy1=top-left, xy2=bottom-right
+    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+    y[:, 2] = x[:, 2] - x[:, 0]  # width
+    y[:, 3] = x[:, 3] - x[:, 1]  # height
+    return y
+
+
+def ltwh2xywh(x):
+    # Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center
+    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+    y[:, 0] = x[:, 0] + x[:, 2] / 2  # center x
+    y[:, 1] = x[:, 1] + x[:, 3] / 2  # center y
+    return y
+
+
+def ltwh2xyxy(x):
+    # Convert nx4 boxes from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+    y[:, 2] = x[:, 2] + x[:, 0]  # width
+    y[:, 3] = x[:, 3] + x[:, 1]  # height
+    return y
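+
+
+# Editor's sketch — the converters above are mutually consistent, e.g.:
+#   b = np.array([[10., 20., 50., 80.]])   # xyxy
+#   xyxy2xywh(b)                           # -> [[30., 50., 40., 60.]]
+#   xywh2xyxy(xyxy2xywh(b))                # -> [[10., 20., 50., 80.]]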
+
+
+def segments2boxes(segments):
+    # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
+    boxes = []
+    for s in segments:
+        x, y = s.T  # segment xy
+        boxes.append([x.min(), y.min(), x.max(), y.max()])  # xyxy
+    return xyxy2xywh(np.array(boxes))  # xywh
+
+
+def resample_segments(segments, n=1000):
+    # Up-sample an (n,2) segment
+    for i, s in enumerate(segments):
+        s = np.concatenate((s, s[0:1, :]), axis=0)
+        x = np.linspace(0, len(s) - 1, n)
+        xp = np.arange(len(s))
+        segments[i] = np.concatenate([np.interp(x, xp, s[:, j]) for j in range(2)]).reshape(2, -1).T  # segment xy
+    return segments
+
+
+def increment_path(path, exist_ok=False, sep='', mkdir=False):
+    """
+    Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
+    # TODO: docs
+    """
+    path = Path(path)  # os-agnostic
+    if path.exists() and not exist_ok:
+        path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
+
+        # Method 1
+        for n in range(2, 9999):
+            p = f'{path}{sep}{n}{suffix}'  # increment path
+            if not os.path.exists(p):
+                break
+        path = Path(p)
+
+    if mkdir:
+        path.mkdir(parents=True, exist_ok=True)  # make directory
+
+    return path
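+
+
+# e.g. with 'runs/exp' already on disk:
+#   increment_path('runs/exp')           # -> Path('runs/exp2')
+#   increment_path('runs/exp', sep='_')  # -> Path('runs/exp_2')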
+
+
+def save_yaml(file='data.yaml', data=None):
+    # Single-line safe yaml saving; avoid a mutable default argument
+    data = {} if data is None else data
+    with open(file, 'w') as f:
+        yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
+
+
+def download(url, dir=Path.cwd(), unzip=True, delete=True, curl=False, threads=1, retry=3):
+    # Multithreaded file download and unzip function, used in data.yaml for autodownload
+    def download_one(url, dir):
+        # Download 1 file
+        success = True
+        if Path(url).is_file():
+            f = Path(url)  # filename
+        else:  # does not exist
+            f = dir / Path(url).name
+            LOGGER.info(f'Downloading {url} to {f}...')
+            for i in range(retry + 1):
+                if curl:
+                    s = 'sS' if threads > 1 else ''  # silent
+                    r = os.system(
+                        f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -')  # curl download with retry, continue
+                    success = r == 0
+                else:
+                    torch.hub.download_url_to_file(url, f, progress=threads == 1)  # torch download
+                    success = f.is_file()
+                if success:
+                    break
+                elif i < retry:
+                    LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
+                else:
+                    LOGGER.warning(f'❌ Failed to download {url}...')
+
+        if unzip and success and f.suffix in ('.zip', '.tar', '.gz'):
+            LOGGER.info(f'Unzipping {f}...')
+            if f.suffix == '.zip':
+                ZipFile(f).extractall(path=dir)  # unzip
+            elif f.suffix == '.tar':
+                os.system(f'tar xf {f} --directory {f.parent}')  # unzip
+            elif f.suffix == '.gz':
+                os.system(f'tar xfz {f} --directory {f.parent}')  # unzip
+            if delete:
+                f.unlink()  # remove zip
+
+    dir = Path(dir)
+    dir.mkdir(parents=True, exist_ok=True)  # make directory
+    if threads > 1:
+        pool = ThreadPool(threads)
+        pool.imap(lambda x: download_one(*x), zip(url, repeat(dir)))  # multithreaded
+        pool.close()
+        pool.join()
+    else:
+        for u in [url] if isinstance(url, (str, Path)) else url:
+            download_one(u, dir)
+
+
+class WorkingDirectory(contextlib.ContextDecorator):
+    # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
+    def __init__(self, new_dir):
+        self.dir = new_dir  # new dir
+        self.cwd = Path.cwd().resolve()  # current dir
+
+    def __enter__(self):
+        os.chdir(self.dir)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        os.chdir(self.cwd)
+
+
+def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
+    # Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
+    file = Path(file)
+    assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
+    try:  # url1
+        LOGGER.info(f'Downloading {url} to {file}...')
+        torch.hub.download_url_to_file(url, str(file), progress=LOGGER.level <= logging.INFO)
+        assert file.exists() and file.stat().st_size > min_bytes, assert_msg  # check
+    except Exception as e:  # url2
+        if file.exists():
+            file.unlink()  # remove partial downloads
+        LOGGER.info(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...')
+        os.system(f"curl -# -L '{url2 or url}' -o '{file}' --retry 3 -C -")  # curl download, retry and resume on fail
+    finally:
+        if not file.exists() or file.stat().st_size < min_bytes:  # check
+            if file.exists():
+                file.unlink()  # remove partial downloads
+            LOGGER.info(f"ERROR: {assert_msg}\n{error_msg}")
+        LOGGER.info('')
+
+
+def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'):
+    # Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
+    def github_assets(repository, version='latest'):
+        # Return GitHub repo tag and assets (i.e. ['yolov5s.pt', 'yolov5m.pt', ...])
+        if version != 'latest':
+            version = f'tags/{version}'  # i.e. tags/v6.2
+        response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json()  # github api
+        return response['tag_name'], [x['name'] for x in response['assets']]  # tag, assets
+
+    file = Path(str(file).strip().replace("'", ''))
+    if not file.exists():
+        # URL specified
+        name = Path(urllib.parse.unquote(str(file))).name  # decode '%2F' to '/' etc.
+        if str(file).startswith(('http:/', 'https:/')):  # download
+            url = str(file).replace(':/', '://')  # Pathlib turns :// -> :/
+            file = name.split('?')[0]  # parse authentication https://url.com/file.txt?auth...
+            if Path(file).is_file():
+                LOGGER.info(f'Found {url} locally at {file}')  # file already exists
+            else:
+                safe_download(file=file, url=url, min_bytes=1E5)
+            return file
+
+        # GitHub assets
+        assets = [f'yolov5{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')]  # default
+        try:
+            tag, assets = github_assets(repo, release)
+        except Exception:
+            try:
+                tag, assets = github_assets(repo)  # latest release
+            except Exception:
+                try:
+                    tag = subprocess.check_output('git tag', shell=True, stderr=subprocess.STDOUT).decode().split()[-1]
+                except Exception:
+                    tag = release
+
+        file.parent.mkdir(parents=True, exist_ok=True)  # make parent dir (if required)
+        if name in assets:
+            url3 = 'https://drive.google.com/drive/folders/1EFQTEUeXWSFww0luse2jB9M1QNZQGwNl'  # backup gdrive mirror
+            safe_download(
+                file,
+                url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
+                min_bytes=1E5,
+                error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag} or {url3}')
+
+    return str(file)
+
+
+def get_model(model: str):
+    # check for local weights
+    pass
diff --git a/ultralytics/yolo/utils/instance.py b/ultralytics/yolo/utils/instance.py
new file mode 100644
index 00000000..bbf5d6d8
--- /dev/null
+++ b/ultralytics/yolo/utils/instance.py
@@ -0,0 +1,326 @@
+from collections import abc
+from itertools import repeat
+from numbers import Number
+from typing import List
+
+import numpy as np
+
+from .general import ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh
+
+
+# From PyTorch internals
+def _ntuple(n):
+
+    def parse(x):
+        return x if isinstance(x, abc.Iterable) else tuple(repeat(x, n))
+
+    return parse
+
+
+to_4tuple = _ntuple(4)
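+# e.g. to_4tuple(640) -> (640, 640, 640, 640); an iterable like (1, 2, 3, 4) passes through unchanged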
+
+# `xyxy` means left top and right bottom
+# `xywh` means center x, center y and width, height(yolo format)
+# `ltwh` means left top and width, height(coco format)
+_formats = ["xyxy", "xywh", "ltwh"]
+
+__all__ = ["Bboxes"]
+
+
+class Bboxes:
+    """Now only numpy is supported"""
+
+    def __init__(self, bboxes, format="xyxy") -> None:
+        assert format in _formats
+        bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
+        assert bboxes.ndim == 2
+        assert bboxes.shape[1] == 4
+        self.bboxes = bboxes
+        self.format = format
+        # self.normalized = normalized
+
+    def convert(self, format):
+        assert format in _formats
+        if self.format == format:
+            return
+        elif self.format == "xyxy":
+            bboxes = xyxy2xywh(self.bboxes) if format == "xywh" else xyxy2ltwh(self.bboxes)
+        elif self.format == "xywh":
+            bboxes = xywh2xyxy(self.bboxes) if format == "xyxy" else xywh2ltwh(self.bboxes)
+        else:
+            bboxes = ltwh2xyxy(self.bboxes) if format == "xyxy" else ltwh2xywh(self.bboxes)
+        self.bboxes = bboxes
+        self.format = format
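+
+    # Editor's sketch — convert() mutates the box container in place:
+    #   boxes = Bboxes(np.array([[10., 20., 50., 80.]]), format="xyxy")
+    #   boxes.convert("xywh")
+    #   boxes.bboxes  # -> [[30., 50., 40., 60.]]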
+
+    def areas(self):
+        self.convert("xyxy")
+        return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])
+
+    def mul(self, scale):
+        """
+        Args:
+            scale (tuple | List | int): the scale for four coords.
+        """
+        if isinstance(scale, Number):
+            scale = to_4tuple(scale)
+        assert isinstance(scale, (tuple, list))
+        assert len(scale) == 4
+        self.bboxes[:, 0] *= scale[0]
+        self.bboxes[:, 1] *= scale[1]
+        self.bboxes[:, 2] *= scale[2]
+        self.bboxes[:, 3] *= scale[3]
+
+    def add(self, offset):
+        """
+        Args:
+            offset (tuple | List | int): the offset for four coords.
+        """
+        if isinstance(offset, Number):
+            offset = to_4tuple(offset)
+        assert isinstance(offset, (tuple, list))
+        assert len(offset) == 4
+        self.bboxes[:, 0] += offset[0]
+        self.bboxes[:, 1] += offset[1]
+        self.bboxes[:, 2] += offset[2]
+        self.bboxes[:, 3] += offset[3]
+
+    def __len__(self):
+        return len(self.bboxes)
+
+    @classmethod
+    def concatenate(cls, boxes_list: List["Bboxes"], axis=0) -> "Bboxes":
+        """
+        Concatenates a list of Bboxes into a single Bboxes object.
+
+        Arguments:
+            boxes_list (list[Bboxes])
+
+        Returns:
+            Bboxes: the concatenated Bboxes
+        """
+        assert isinstance(boxes_list, (list, tuple))
+        if not boxes_list:
+            return cls(np.empty((0, 4)))
+        assert all(isinstance(box, Bboxes) for box in boxes_list)
+
+        if len(boxes_list) == 1:
+            return boxes_list[0]
+        return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))
+
+    def __getitem__(self, index) -> "Bboxes":
+        """
+        Args:
+            index: int, slice, or a BoolArray
+
+        Returns:
+            Bboxes: Create a new :class:`Bboxes` by indexing.
+        """
+        if isinstance(index, int):
+            return Bboxes(self.bboxes[index].reshape(1, -1))  # numpy reshape; view() is torch-only
+        b = self.bboxes[index]
+        assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!"
+        return Bboxes(b)
+
+
+class Instances:
+
+    def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None:
+        """
+        Args:
+            bboxes (ndarray): bboxes with shape [N, 4].
+            segments (list | ndarray): segments.
+            keypoints (ndarray): keypoints with shape [N, 17, 2].
+        """
+        self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)
+        self.keypoints = keypoints
+        self.normalized = normalized
+
+        if isinstance(segments, list) and len(segments) > 0:
+            # list[np.array(1000, 2)] * num_samples
+            segments = resample_segments(segments)
+            # (N, 1000, 2)
+            segments = np.stack(segments, axis=0)
+        self.segments = segments
+
+    def convert_bbox(self, format):
+        self._bboxes.convert(format=format)
+
+    def bbox_areas(self):
+        return self._bboxes.areas()
+
+    def scale(self, scale_w, scale_h, bbox_only=False):
+        """this might be similar with denormalize func but without normalized sign"""
+        self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
+        if bbox_only:
+            return
+        if self.segments is not None:
+            self.segments[..., 0] *= scale_w
+            self.segments[..., 1] *= scale_h
+        if self.keypoints is not None:
+            self.keypoints[..., 0] *= scale_w
+            self.keypoints[..., 1] *= scale_h
+
+    def denormalize(self, w, h):
+        if not self.normalized:
+            return
+        self._bboxes.mul(scale=(w, h, w, h))
+        if self.segments is not None:
+            self.segments[..., 0] *= w
+            self.segments[..., 1] *= h
+        if self.keypoints is not None:
+            self.keypoints[..., 0] *= w
+            self.keypoints[..., 1] *= h
+        self.normalized = False
+
+    def normalize(self, w, h):
+        if self.normalized:
+            return
+        self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h))
+        if self.segments is not None:
+            self.segments[..., 0] /= w
+            self.segments[..., 1] /= h
+        if self.keypoints is not None:
+            self.keypoints[..., 0] /= w
+            self.keypoints[..., 1] /= h
+        self.normalized = True
+
+    def add_padding(self, padw, padh):
+        # handle rect and mosaic situation
+        assert not self.normalized, "you should add padding with absolute coordinates."
+        self._bboxes.add(offset=(padw, padh, padw, padh))
+        if self.segments is not None:
+            self.segments[..., 0] += padw
+            self.segments[..., 1] += padh
+        if self.keypoints is not None:
+            self.keypoints[..., 0] += padw
+            self.keypoints[..., 1] += padh
+
+    def __getitem__(self, index) -> "Instances":
+        """
+        Args:
+            index: int, slice, or a BoolArray
+
+        Returns:
+            Instances: Create a new :class:`Instances` by indexing.
+        """
+        segments = self.segments[index] if self.segments is not None else None
+        keypoints = self.keypoints[index] if self.keypoints is not None else None
+        bboxes = self.bboxes[index]
+        bbox_format = self._bboxes.format
+        return Instances(
+            bboxes=bboxes,
+            segments=segments,
+            keypoints=keypoints,
+            bbox_format=bbox_format,
+            normalized=self.normalized,
+        )
+
+    def flipud(self, h):
+        # convenience helper for the flipud augmentation (expects xywh boxes; flips the y center)
+        self.bboxes[:, 1] = h - self.bboxes[:, 1]
+        if self.segments is not None:
+            self.segments[..., 1] = h - self.segments[..., 1]
+        if self.keypoints is not None:
+            self.keypoints[..., 1] = h - self.keypoints[..., 1]
+
+    def fliplr(self, w):
+        # convenience helper for the fliplr augmentation (expects xywh boxes; flips the x center)
+        self.bboxes[:, 0] = w - self.bboxes[:, 0]
+        if self.segments is not None:
+            self.segments[..., 0] = w - self.segments[..., 0]
+        if self.keypoints is not None:
+            self.keypoints[..., 0] = w - self.keypoints[..., 0]
+
+    def clip(self, w, h):
+        self.convert_bbox(format="xyxy")
+        self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
+        self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
+        if self.segments is not None:
+            self.segments[..., 0] = self.segments[..., 0].clip(0, w)
+            self.segments[..., 1] = self.segments[..., 1].clip(0, h)
+        if self.keypoints is not None:
+            self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w)
+            self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)
+
+    def update(self, bboxes, segments=None, keypoints=None):
+        new_bboxes = Bboxes(bboxes, format=self._bboxes.format)
+        self._bboxes = new_bboxes
+        if segments is not None:
+            self.segments = segments
+        if keypoints is not None:
+            self.keypoints = keypoints
+
+    def __len__(self):
+        return len(self.bboxes)
+
+    @classmethod
+    def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances":
+        """
+        Concatenates a list of Boxes into a single Bboxes
+
+        Arguments:
+            instances_list (list[Bboxes])
+            axis
+
+        Returns:
+            Boxes: the concatenated Boxes
+        """
+        assert isinstance(instances_list, (list, tuple))
+        if not instances_list:
+            return cls(np.empty(0))
+        assert all(isinstance(instance, Instances) for instance in instances_list)
+
+        if len(instances_list) == 1:
+            return instances_list[0]
+
+        use_segment = instances_list[0].segments is not None
+        use_keypoint = instances_list[0].keypoints is not None
+        bbox_format = instances_list[0]._bboxes.format
+        normalized = instances_list[0].normalized
+
+        cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)
+        cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis) if use_segment else None
+        cat_keypoints = np.concatenate([b.keypoints for b in instances_list], axis=axis) if use_keypoint else None
+        return cls(cat_boxes, cat_segments, cat_keypoints, bbox_format, normalized)
+
+    @property
+    def bboxes(self):
+        return self._bboxes.bboxes
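+
+    # Editor's sketch of a typical flow inside the data pipeline:
+    #   inst = Instances(np.array([[0.5, 0.5, 0.2, 0.2]]), bbox_format="xywh", normalized=True)
+    #   inst.denormalize(640, 640)    # absolute coordinates
+    #   inst.convert_bbox("xyxy")
+    #   inst.clip(640, 640)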
diff --git a/ultralytics/yolo/utils/loggers/__init__.py b/ultralytics/yolo/utils/loggers/__init__.py
new file mode 100644
index 00000000..1ff2c17d
--- /dev/null
+++ b/ultralytics/yolo/utils/loggers/__init__.py
@@ -0,0 +1,3 @@
+from .base import default_callbacks
+
+__all__ = ["default_callbacks"]
diff --git a/ultralytics/yolo/utils/loggers/base.py b/ultralytics/yolo/utils/loggers/base.py
new file mode 100644
index 00000000..0c2d855d
--- /dev/null
+++ b/ultralytics/yolo/utils/loggers/base.py
@@ -0,0 +1,32 @@
+def before_train(trainer):
+    # Initialize tensorboard logger
+    pass
+
+
+def on_epoch_start(trainer):
+    pass
+
+
+def on_batch_start(trainer):
+    pass
+
+
+def on_val_start(trainer):
+    pass
+
+
+def on_val_end(trainer):
+    pass
+
+
+def on_model_save(trainer):
+    pass
+
+
+default_callbacks = {
+    "before_train": before_train,
+    "on_epoch_start": on_epoch_start,
+    "on_batch_start": on_batch_start,
+    "on_val_start": on_val_start,
+    "on_val_end": on_val_end,
+    "on_model_save": on_model_save}
diff --git a/ultralytics/yolo/utils/metrics.py b/ultralytics/yolo/utils/metrics.py
new file mode 100644
index 00000000..ea1b3abb
--- /dev/null
+++ b/ultralytics/yolo/utils/metrics.py
@@ -0,0 +1,27 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+"""
+Model validation metrics
+"""
+import numpy as np
+
+
+def bbox_ioa(box1, box2, eps=1e-7):
+    """Returns the intersection over box2 area given box1, box2. Boxes are x1y1x2y2
+    box1:       np.array of shape(4)
+    box2:       np.array of shape(nx4)
+    returns:    np.array of shape(n)
+    """
+
+    # Get the coordinates of bounding boxes
+    b1_x1, b1_y1, b1_x2, b1_y2 = box1
+    b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
+
+    # Intersection area
+    inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
+                 (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)
+
+    # box2 area
+    box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + eps
+
+    # Intersection over box2 area
+    return inter_area / box2_area
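+
+
+# Editor's sketch: a box half-covered by another yields IoA 0.5 w.r.t. box2's area:
+#   bbox_ioa(np.array([0, 0, 10, 10]), np.array([[5, 0, 15, 10]]))  # -> ~array([0.5])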
diff --git a/ultralytics/yolo/utils/torch_utils.py b/ultralytics/yolo/utils/torch_utils.py
new file mode 100644
index 00000000..fcfb8b3d
--- /dev/null
+++ b/ultralytics/yolo/utils/torch_utils.py
@@ -0,0 +1,70 @@
+import os
+from contextlib import contextmanager
+
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from ultralytics.yolo.utils import check_version
+
+LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
+RANK = int(os.getenv('RANK', -1))
+WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
+
+
+@contextmanager
+def torch_distributed_zero_first(local_rank: int):
+    # Decorator to make all processes in distributed training wait for each local_master to do something
+    if local_rank not in [-1, 0]:
+        dist.barrier(device_ids=[local_rank])
+    yield
+    if local_rank == 0:
+        dist.barrier(device_ids=[0])
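+
+
+# Editor's sketch: let rank 0 prepare/cache the dataset before the other ranks proceed:
+#   with torch_distributed_zero_first(LOCAL_RANK):
+#       dataset = prepare_dataset()  # hypothetical helper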
+
+
+def DDP_model(model):
+    # Model DDP creation with checks
+    assert not check_version(torch.__version__, '1.12.0', pinned=True), \
+        'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
+        'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
+    if check_version(torch.__version__, '1.11.0'):
+        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
+    else:
+        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
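+
+
+# Usage sketch (illustrative): move the model to its local device before wrapping:
+#   model = DDP_model(model.to(LOCAL_RANK))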
+
+
+def select_device(device='', batch_size=0, newline=True):
+    # device = None or 'cpu' or 0 or '0' or '0,1,2,3'
+    # s = f'YOLOv5 🚀 {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} '
+    s = f'YOLOv5 🚀 torch-{torch.__version__} '
+    device = str(device).strip().lower().replace('cuda:', '').replace('none', '')  # to string, 'cuda:0' to '0'
+    cpu = device == 'cpu'
+    mps = device == 'mps'  # Apple Metal Performance Shaders (MPS)
+    if cpu or mps:
+        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
+    elif device:  # non-cpu device requested
+        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable - must be before assert is_available()
+        assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
+            f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"
+
+    if not cpu and not mps and torch.cuda.is_available():  # prefer GPU if available
+        devices = device.split(',') if device else '0'  # e.g. ['0', '1']; iterating the default '0' string also yields one device
+        n = len(devices)  # device count
+        if n > 1 and batch_size > 0:  # check batch_size is divisible by device_count
+            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
+        space = ' ' * (len(s) + 1)
+        for i, d in enumerate(devices):
+            p = torch.cuda.get_device_properties(i)
+            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n"  # bytes to MB
+        arg = 'cuda:0'
+    elif mps and getattr(torch, 'has_mps', False) and torch.backends.mps.is_available():  # prefer MPS if available
+        s += 'MPS\n'
+        arg = 'mps'
+    else:  # revert to CPU
+        s += 'CPU\n'
+        arg = 'cpu'
+
+    if not newline:
+        s = s.rstrip()
+    print(s)
+    return torch.device(arg)
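+
+
+# Usage sketch (illustrative):
+#   select_device('')          # first available CUDA device, else CPU
+#   select_device('0,1', 32)   # two GPUs; batch size 32 must be divisible by 2
+#   select_device('cpu')       # force CPU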
diff --git a/ultralytics/yolo/v8/__init__.py b/ultralytics/yolo/v8/__init__.py
new file mode 100644
index 00000000..6d856d60
--- /dev/null
+++ b/ultralytics/yolo/v8/__init__.py
@@ -0,0 +1,7 @@
+from pathlib import Path
+
+from ultralytics.yolo.v8 import classify
+
+ROOT = Path(__file__).parents[0]  # yolov8 ROOT
+
+__all__ = ["classify"]
diff --git a/ultralytics/yolo/v8/classify/__init__.py b/ultralytics/yolo/v8/classify/__init__.py
new file mode 100644
index 00000000..278a980f
--- /dev/null
+++ b/ultralytics/yolo/v8/classify/__init__.py
@@ -0,0 +1,3 @@
+from ultralytics.yolo.v8.classify import train
+
+__all__ = ["train"]
diff --git a/ultralytics/yolo/v8/classify/train.py b/ultralytics/yolo/v8/classify/train.py
new file mode 100644
index 00000000..30ac3b40
--- /dev/null
+++ b/ultralytics/yolo/v8/classify/train.py
@@ -0,0 +1,81 @@
+import subprocess
+import time
+from pathlib import Path
+
+import hydra
+import torch
+import torch.hub as hub
+import torchvision
+import torchvision.transforms as T
+from omegaconf import DictConfig, OmegaConf
+
+from ultralytics.yolo import BaseTrainer, utils, v8
+from ultralytics.yolo.data import build_classification_dataloader
+from ultralytics.yolo.engine.trainer import CONFIG_PATH_ABS, DEFAULT_CONFIG
+
+
+# BaseTrainer python usage
+class Trainer(BaseTrainer):
+
+    def get_dataset(self):
+        # temporary solution. Replace with new ultralytics.yolo.ClassificationDataset module
+        data = Path("datasets") / self.data
+        with utils.torch_distributed_zero_first(utils.LOCAL_RANK), utils.WorkingDirectory(Path.cwd()):
+            data_dir = data if data.is_dir() else (Path.cwd() / data)
+            if not data_dir.is_dir():
+                self.console.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
+                t = time.time()
+                if str(data) == 'imagenet':
+                    subprocess.run(f"bash {v8.ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
+                else:
+                    url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{self.data}.zip'
+                    utils.download(url, dir=data_dir.parent)
+                # TODO: add colorstr
+                s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {'bold', data_dir}\n"
+                self.console.info(s)
+        train_set = data_dir / "train"
+        test_set = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val'  # data/test or data/val
+
+        return train_set, test_set
+
+    def get_dataloader(self, dataset, batch_size=None, rank=0):
+        # batch size is taken from the train config; the batch_size argument is currently unused
+        return build_classification_dataloader(path=dataset, batch_size=self.train.batch_size, rank=rank)
+
+    def get_model(self):
+        # temp. minimal. only supports torchvision models
+        if self.model in torchvision.models.__dict__:  # TorchVision models i.e. resnet50, efficientnet_b0
+            model = torchvision.models.__dict__[self.model](weights='IMAGENET1K_V1' if self.train.pretrained else None)
+        else:
+            raise ModuleNotFoundError(f'--model {self.model} not found.')
+        for m in model.modules():
+            if not self.train.pretrained and hasattr(m, 'reset_parameters'):
+                m.reset_parameters()
+        for p in model.parameters():
+            p.requires_grad = True  # for training
+
+        return model
+
+
+@hydra.main(version_base=None, config_path=CONFIG_PATH_ABS, config_name=str(DEFAULT_CONFIG).split(".")[0])
+def train(cfg):
+    model = "squeezenet1_0"
+    dataset = "imagenette160"  # or yolo.ClassificationDataset("mnist")
+    criterion = torch.nn.CrossEntropyLoss()  # yolo.Loss object
+    trainer = Trainer(model, dataset, criterion, config=cfg)
+    trainer.run()
+
+
+if __name__ == "__main__":
+    """
+    CLI usage:
+    python ../path/to/train.py train.epochs=10 train.project="name" hyps.lr0=0.1
+
+    TODO:
+    Direct CLI support, i.e. yolov8 classify_train train.epochs 10
+    """
+    train()