From 1054819a598f39fcb6d94179f3c4344604c452f2 Mon Sep 17 00:00:00 2001
From: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Date: Wed, 26 Oct 2022 01:21:15 +0530
Subject: [PATCH] Add initial model interface (#30)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 ultralytics/tests/test_model.py              | 13 +++
 ultralytics/yolo/__init__.py                 |  5 +-
 ultralytics/yolo/data/augment.py             |  2 +-
 ultralytics/yolo/data/utils.py               |  7 +-
 ultralytics/yolo/engine/model.py             | 63 +++++++++++++
 ultralytics/yolo/engine/trainer.py           | 77 ++++++++--------
 ultralytics/yolo/utils/configs/defaults.yaml | 95 ++++++++++----------
 ultralytics/yolo/utils/modeling/tasks.py     | 19 +++-
 ultralytics/yolo/utils/torch_utils.py        |  5 ++
 ultralytics/yolo/v8/classify/__init__.py     |  3 +-
 ultralytics/yolo/v8/classify/train.py        | 36 ++++----
 ultralytics/yolo/v8/classify/val.py          |  4 +-
 12 files changed, 220 insertions(+), 109 deletions(-)
 create mode 100644 ultralytics/tests/test_model.py
 create mode 100644 ultralytics/yolo/engine/model.py

diff --git a/ultralytics/tests/test_model.py b/ultralytics/tests/test_model.py
new file mode 100644
index 00000000..353fab1a
--- /dev/null
+++ b/ultralytics/tests/test_model.py
@@ -0,0 +1,13 @@
+from ultralytics.yolo import YOLO
+
+
+def test_model():
+    model = YOLO()
+    model.new("assets/dummy_model.yaml")
+    model.model = "squeezenet1_0"  # temp solution before get_model is implemented
+    # model.load("yolov5n.pt")
+    model.train(data="imagenette160", epochs=1, lr0=0.01)
+
+
+if __name__ == "__main__":
+    test_model()
diff --git a/ultralytics/yolo/__init__.py b/ultralytics/yolo/__init__.py
index 85f6f6c2..fa1c3b26 100644
--- a/ultralytics/yolo/__init__.py
+++ b/ultralytics/yolo/__init__.py
@@ -1,4 +1,7 @@
+import ultralytics.yolo.v8 as v8
+
+from .engine.model import YOLO
 from .engine.trainer import BaseTrainer
 from .engine.validator import BaseValidator
 
-__all__ = ["BaseTrainer", "BaseValidator"]  # allow simpler import
+__all__ = ["BaseTrainer", "BaseValidator", "YOLO"]  # allow simpler import
diff --git a/ultralytics/yolo/data/augment.py b/ultralytics/yolo/data/augment.py
index 553b057b..6c936ad8 100644
--- a/ultralytics/yolo/data/augment.py
+++ b/ultralytics/yolo/data/augment.py
@@ -728,7 +728,7 @@ def classify_albumentations(
                 if vflip > 0:
                     T += [A.VerticalFlip(p=vflip)]
                 if jitter > 0:
-                    color_jitter = (float(jitter),) * 3  # repeat value for brightness, contrast, satuaration, 0 hue
+                    color_jitter = (float(jitter),) * 3  # repeat value for brightness, contrast, saturation, 0 hue
                     T += [A.ColorJitter(*color_jitter, 0)]
         else:  # Use fixed crop for eval set (reproducibility)
             T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
diff --git a/ultralytics/yolo/data/utils.py b/ultralytics/yolo/data/utils.py
index 70b79d40..00277424 100644
--- a/ultralytics/yolo/data/utils.py
+++ b/ultralytics/yolo/data/utils.py
@@ -51,7 +51,8 @@ def exif_size(img):
 def verify_image_label(args):
     # Verify one image-label pair
     im_file, lb_file, prefix, keypoint = args
-    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", None, None  # number (missing, found, empty, corrupt), message, segments, keypoints
+    # number (missing, found, empty, corrupt), message, segments, keypoints
+    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", None, None
     try:
         # verify images
         im = Image.open(im_file)
@@ -86,10 +87,10 @@ def verify_image_label(args):
                     kpts = np.zeros((lb.shape[0], 39))
                     for i in range(len(lb)):
                         kpt = np.delete(lb[i, 5:], np.arange(2, lb.shape[1] - 5,
-                                                             3))  # remove the occlusion paramater from the GT
+                                                             3))  # remove the occlusion parameter from the GT
                         kpts[i] = np.hstack((lb[i, :5], kpt))
                     lb = kpts
-                    assert lb.shape[1] == 39, "labels require 39 columns each after removing occlusion paramater"
+                    assert lb.shape[1] == 39, "labels require 39 columns each after removing occlusion parameter"
                 else:
                     assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
                     assert (lb >= 0).all(), f"negative label values {lb[lb < 0]}"
diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py
new file mode 100644
index 00000000..838014b7
--- /dev/null
+++ b/ultralytics/yolo/engine/model.py
@@ -0,0 +1,63 @@
+"""
+Top-level YOLO model interface. First principle usage example - https://github.com/ultralytics/ultralytics/issues/13
+"""
+import torch
+import yaml
+
+import ultralytics.yolo as yolo
+from ultralytics.yolo.utils import LOGGER
+from ultralytics.yolo.utils.checks import check_yaml
+from ultralytics.yolo.utils.modeling.tasks import ClassificationModel, DetectionModel, SegmentationModel
+
+# map head: [model, trainer]
+MODEL_MAP = {
+    "Classify": [ClassificationModel, 'yolo.VERSION.classify.train.ClassificationTrainer'],
+    "Detect": [ClassificationModel, 'yolo.VERSION.classify.train.ClassificationTrainer'],  # temp
+    "Segment": []}
+
+
+class YOLO:
+
+    def __init__(self, version=8) -> None:
+        self.version = version
+        self.model = None
+        self.trainer = None
+        self.pretrained_weights = None
+
+    def new(self, cfg: str):
+        cfg = check_yaml(cfg)  # check YAML
+        self.model, self.trainer = self._get_model_and_trainer(cfg)
+
+    def load(self, weights, autodownload=True):
+        if not isinstance(self.pretrained_weights, type(None)):
+            LOGGER.info("Overwriting weights")
+        # TODO: weights = smart_file_loader(weights)
+        if self.model:
+            self.model.load(weights)
+            LOGGER.info("Checkpoint loaded successfully")
+        else:
+            # TODO: infer model and trainer
+            pass
+
+        self.pretrained_weights = weights
+
+    def reset(self):
+        pass
+
+    def train(self, **kwargs):
+        if 'data' not in kwargs:
+            raise Exception("data is required to train")
+        if not self.model:
+            raise Exception("model not initialized. Use .new() or .load()")
+        kwargs["model"] = self.model
+        trainer = self.trainer(overrides=kwargs)
+        trainer.train()
+
+    def _get_model_and_trainer(self, cfg):
+        with open(cfg, encoding='ascii', errors='ignore') as f:
+            cfg = yaml.safe_load(f)  # model dict
+        model, trainer = MODEL_MAP[cfg["head"][-1][-2]]
+        # warning: eval is unsafe. Use with caution
+        trainer = eval(trainer.replace("VERSION", f"v{self.version}"))
+
+        return model(cfg), trainer
diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py
index 6ba33cde..875bc351 100644
--- a/ultralytics/yolo/engine/trainer.py
+++ b/ultralytics/yolo/engine/trainer.py
@@ -7,7 +7,7 @@ import time
 from collections import defaultdict
 from datetime import datetime
 from pathlib import Path
-from typing import Union
+from typing import Dict, Union
 
 import torch
 import torch.distributed as dist
@@ -29,30 +29,29 @@ DEFAULT_CONFIG = "defaults.yaml"
 
 class BaseTrainer:
 
-    def __init__(self, config=CONFIG_PATH_ABS / DEFAULT_CONFIG):
+    def __init__(self, config=CONFIG_PATH_ABS / DEFAULT_CONFIG, overrides={}):
         self.console = LOGGER
-        self.model, self.data, self.train, self.hyps = self._get_config(config)
+        self.args = self._get_config(config, overrides)
         self.validator = None
         self.callbacks = defaultdict(list)
-        self.console.info(f"Training config: \n train: \n {self.train} \n hyps: \n {self.hyps}")  # to debug
+        self.console.info(f"Training config: \n args: \n {self.args}")  # to debug
         # Directories
-        self.save_dir = increment_path(Path(self.train.project) / self.train.name, exist_ok=self.train.exist_ok)
+        self.save_dir = increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
         self.wdir = self.save_dir / 'weights'
         self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
         self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'
 
         # Save run settings
-        save_yaml(self.save_dir / 'train.yaml', OmegaConf.to_container(self.train, resolve=True))
+        save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True))
 
         # device
-        self.device = utils.torch_utils.select_device(self.train.device, self.train.batch_size)
+        self.device = utils.torch_utils.select_device(self.args.device, self.args.batch_size)
         self.console.info(f"running on device {self.device}")
         self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')
 
         # Model and Dataloaders.
-        self.trainset, self.testset = self.get_dataset()  # initialize dataset before as nc is needed for model
-        self.model = self.get_model()
-        self.model = self.model.to(self.device)
+        self.trainset, self.testset = self.get_dataset(self.args.data)
+        self.model = self.get_model(self.args.model, self.args.pretrained).to(self.device)
 
         # epoch level metrics
         self.metrics = {}  # handle metrics returned by validator
@@ -63,18 +62,24 @@ class BaseTrainer:
         for callback, func in loggers.default_callbacks.items():
             self.add_callback(callback, func)
 
-    def _get_config(self, config: Union[str, Path, DictConfig] = None):
+    def _get_config(self, config: Union[str, DictConfig], overrides: Union[str, Dict] = {}):
         """
         Accepts yaml file name or DictConfig containing experiment configuration.
-        Returns train and hyps namespace
+        Returns training args namespace
         :param config: Optional file name or DictConfig object
         """
-        try:
-            if isinstance(config, (str, Path)):
-                config = OmegaConf.load(config)
-            return config.model, config.data, config.train, config.hyps
-        except KeyError as e:
-            raise KeyError("Missing key(s) in config") from e
+        if isinstance(config, (str, Path)):
+            config = OmegaConf.load(config)
+        elif isinstance(config, Dict):
+            config = OmegaConf.create(config)
+
+        # override
+        if isinstance(overrides, str):
+            overrides = OmegaConf.load(overrides)
+        elif isinstance(overrides, Dict):
+            overrides = OmegaConf.create(overrides)
+
+        return OmegaConf.merge(config, overrides)
 
     def add_callback(self, onevent: str, callback):
         """
@@ -92,7 +97,7 @@ class BaseTrainer:
         for callback in self.callbacks.get(onevent, []):
             callback(self)
 
-    def run(self):
+    def train(self):
         world_size = torch.cuda.device_count()
         if world_size > 1:
             mp.spawn(self._do_train, args=(world_size,), nprocs=world_size, join=True)
@@ -109,21 +114,21 @@ class BaseTrainer:
         dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
         self.model = self.model.to(self.device)
         self.model = DDP(self.model, device_ids=[rank])
-        self.train.batch_size = self.train.batch_size // world_size
+        self.args.batch_size = self.args.batch_size // world_size
 
     def _setup_train(self, rank):
         """
         Builds dataloaders and optimizer on correct rank process
         """
         self.optimizer = build_optimizer(model=self.model,
-                                         name=self.train.optimizer,
-                                         lr=self.hyps.lr0,
-                                         momentum=self.hyps.momentum,
-                                         decay=self.hyps.weight_decay)
-        self.train_loader = self.get_dataloader(self.trainset, batch_size=self.train.batch_size, rank=rank)
+                                         name=self.args.optimizer,
+                                         lr=self.args.lr0,
+                                         momentum=self.args.momentum,
+                                         decay=self.args.weight_decay)
+        self.train_loader = self.get_dataloader(self.trainset, batch_size=self.args.batch_size, rank=rank)
         if rank in {0, -1}:
             print(" Creating testloader rank :", rank)
-            self.test_loader = self.get_dataloader(self.testset, batch_size=self.train.batch_size * 2, rank=rank)
+            self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=rank)
             self.validator = self.get_validator()
             print("created testloader :", rank)
 
@@ -138,7 +143,7 @@ class BaseTrainer:
         self.epoch_time = None
         self.epoch_time_start = time.time()
         self.train_time_start = time.time()
-        for epoch in range(self.train.epochs):
+        for epoch in range(self.args.epochs):
             # callback hook. on_epoch_start
             self.model.train()
             pbar = enumerate(self.train_loader)
@@ -165,7 +170,7 @@ class BaseTrainer:
                 # log
                 mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
                 if rank in {-1, 0}:
-                    pbar.desc = f"{f'{epoch + 1}/{self.train.epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36
+                    pbar.desc = f"{f'{epoch + 1}/{self.args.epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36
 
             if rank in [-1, 0]:
                 # validation
@@ -174,7 +179,7 @@ class BaseTrainer:
                 # callback: on_val_end()
 
                 # save model
-                if (not self.train.nosave) or (self.epoch + 1 == self.train.epochs):
+                if (not self.args.nosave) or (self.epoch + 1 == self.args.epochs):
                     self.save_model()
                     # callback; on_model_save
 
@@ -198,7 +203,7 @@ class BaseTrainer:
             'ema': None,  # deepcopy(ema.ema).half(),
             'updates': None,  # ema.updates,
             'optimizer': None,  # optimizer.state_dict(),
-            'train_args': self.train,
+            'train_args': self.args,
             'date': datetime.now().isoformat()}
 
         # Save last, best and delete
@@ -207,22 +212,22 @@ class BaseTrainer:
             torch.save(ckpt, self.best)
         del ckpt
 
-    def get_dataloader(self, path):
+    def get_dataloader(self, dataset_path, batch_size=16, rank=0):
         """
         Returns dataloader derived from torch.data.Dataloader
         """
         pass
 
-    def get_dataset(self):
+    def get_dataset(self, data):
         """
-        Uses self.dataset to download the dataset if needed and verify it.
+        Download the dataset if needed and verify it.
         Returns train and val split datasets
         """
         pass
 
-    def get_model(self):
+    def get_model(self, model, pretrained=True):
         """
-        Uses self.model to load/create/download dataset for any task
+        load/create/download model for any task
         """
         pass
 
@@ -238,7 +243,7 @@ class BaseTrainer:
 
     def preprocess_batch(self, images, labels):
         """
-        Allows custom preprocessing model inputs and ground truths depeding on task type
+        Allows custom preprocessing model inputs and ground truths depending on task type
         """
         return images.to(self.device, non_blocking=True), labels.to(self.device)
 
diff --git a/ultralytics/yolo/utils/configs/defaults.yaml b/ultralytics/yolo/utils/configs/defaults.yaml
index 35b6e023..b2e74741 100644
--- a/ultralytics/yolo/utils/configs/defaults.yaml
+++ b/ultralytics/yolo/utils/configs/defaults.yaml
@@ -1,53 +1,56 @@
 model: null
 data: null
-train:
-  epochs: 300
-  batch_size: 16
-  img_size: 640
-  nosave: False
-  cache: False # True/ram for ram, or disc
-  device: '' # cuda device, i.e. 0 or 0,1,2,3 or cpu
-  workers: 8
-  project: "ultralytics-yolo"
-  name: "exp" # TODO: make this informative, maybe exp{#number}_{datetime} ?
-  exist_ok: False
-  pretrained: False
-  optimizer: "Adam" # choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
-  verbose: False
-  seed: 0
-  local_rank: -1
 
-hyps:
-  lr0: 0.001  # initial learning rate (SGD=1E-2, Adam=1E-3)
-  lrf: 0.01  # final OneCycleLR learning rate (lr0 * lrf)
-  momentum: 0.937  # SGD momentum/Adam beta1
-  weight_decay: 0.0005  # optimizer weight decay 5e-4
-  warmup_epochs: 3.0  # warmup epochs (fractions ok)
-  warmup_momentum: 0.8  # warmup initial momentum
-  warmup_bias_lr: 0.1  # warmup initial bias lr
-  box: 0.05  # box loss gain
-  cls: 0.5  # cls loss gain
-  cls_pw: 1.0  # cls BCELoss positive_weight
-  obj: 1.0  # obj loss gain (scale with pixels)
-  obj_pw: 1.0  # obj BCELoss positive_weight
-  iou_t: 0.20  # IoU training threshold
-  anchor_t: 4.0  # anchor-multiple threshold
-  # anchors: 3  # anchors per output layer (0 to ignore)
-  fl_gamma: 0.0  # focal loss gamma (efficientDet default gamma=1.5)
-  hsv_h: 0.015  # image HSV-Hue augmentation (fraction)
-  hsv_s: 0.7  # image HSV-Saturation augmentation (fraction)
-  hsv_v: 0.4  # image HSV-Value augmentation (fraction)
-  degrees: 0.0  # image rotation (+/- deg)
-  translate: 0.1  # image translation (+/- fraction)
-  scale: 0.5  # image scale (+/- gain)
-  shear: 0.0  # image shear (+/- deg)
-  perspective: 0.0  # image perspective (+/- fraction), range 0-0.001
-  flipud: 0.0  # image flip up-down (probability)
-  fliplr: 0.5  # image flip left-right (probability)
-  mosaic: 1.0  # image mosaic (probability)
-  mixup: 0.0  # image mixup (probability)
-  copy_paste: 0.0  # segment copy-paste (probability)
+# Training options
+epochs: 300
+batch_size: 16
+img_size: 640
+nosave: False
+cache: False # True/ram for ram, or disc
+device: '' # cuda device, i.e. 0 or 0,1,2,3 or cpu
+workers: 8
+project: "ultralytics-yolo"
+name: "exp" # TODO: make this informative, maybe exp{#number}_{datetime} ?
+exist_ok: False
+pretrained: False
+optimizer: "Adam" # choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
+verbose: False
+seed: 0
+local_rank: -1
+#-----------------------------------#
 
+# Hyper-parameters
+lr0: 0.001  # initial learning rate (SGD=1E-2, Adam=1E-3)
+lrf: 0.01  # final OneCycleLR learning rate (lr0 * lrf)
+momentum: 0.937  # SGD momentum/Adam beta1
+weight_decay: 0.0005  # optimizer weight decay 5e-4
+warmup_epochs: 3.0  # warmup epochs (fractions ok)
+warmup_momentum: 0.8  # warmup initial momentum
+warmup_bias_lr: 0.1  # warmup initial bias lr
+box: 0.05  # box loss gain
+cls: 0.5  # cls loss gain
+cls_pw: 1.0  # cls BCELoss positive_weight
+obj: 1.0  # obj loss gain (scale with pixels)
+obj_pw: 1.0  # obj BCELoss positive_weight
+iou_t: 0.20  # IoU training threshold
+anchor_t: 4.0  # anchor-multiple threshold
+# anchors: 3  # anchors per output layer (0 to ignore)
+fl_gamma: 0.0  # focal loss gamma (efficientDet default gamma=1.5)
+hsv_h: 0.015  # image HSV-Hue augmentation (fraction)
+hsv_s: 0.7  # image HSV-Saturation augmentation (fraction)
+hsv_v: 0.4  # image HSV-Value augmentation (fraction)
+degrees: 0.0  # image rotation (+/- deg)
+translate: 0.1  # image translation (+/- fraction)
+scale: 0.5  # image scale (+/- gain)
+shear: 0.0  # image shear (+/- deg)
+perspective: 0.0  # image perspective (+/- fraction), range 0-0.001
+flipud: 0.0  # image flip up-down (probability)
+fliplr: 0.5  # image flip left-right (probability)
+mosaic: 1.0  # image mosaic (probability)
+mixup: 0.0  # image mixup (probability)
+copy_paste: 0.0  # segment copy-paste (probability)
+
+# Hydra configs -------------------------------------
 # to disable hydra directory creation
 hydra:
   output_subdir: null
diff --git a/ultralytics/yolo/utils/modeling/tasks.py b/ultralytics/yolo/utils/modeling/tasks.py
index 6f2184ed..0fe7e009 100644
--- a/ultralytics/yolo/utils/modeling/tasks.py
+++ b/ultralytics/yolo/utils/modeling/tasks.py
@@ -8,7 +8,8 @@ from ultralytics.yolo.utils import LOGGER
 from ultralytics.yolo.utils.anchors import check_anchor_order
 from ultralytics.yolo.utils.modeling import parse_model
 from ultralytics.yolo.utils.modeling.modules import *
-from ultralytics.yolo.utils.torch_utils import fuse_conv_and_bn, initialize_weights, model_info, scale_img, time_sync
+from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_state_dicts, model_info,
+                                                scale_img, time_sync)
 
 
 class BaseModel(nn.Module):
@@ -67,6 +68,10 @@ class BaseModel(nn.Module):
                 m.anchor_grid = list(map(fn, m.anchor_grid))
         return self
 
+    def load(self, weights):
+        # Force all tasks implement this function
+        raise NotImplementedError("This function needs to be implemented by derived classes!")
+
 
 class DetectionModel(BaseModel):
     # YOLO detection model
@@ -166,6 +171,12 @@ class DetectionModel(BaseModel):
             b.data[:, 5:5 + m.nc] += math.log(0.6 / (m.nc - 0.99999)) if cf is None else torch.log(cf / cf.sum())  # cls
             mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
 
+    def load(self, weights):
+        ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak
+        csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
+        csd = intersect_state_dicts(csd, self.state_dict())  # intersect
+        self.load_state_dict(csd, strict=False)  # load
+
 
 class SegmentationModel(DetectionModel):
     # YOLOv5 segmentation model
@@ -197,3 +208,9 @@ class ClassificationModel(BaseModel):
     def _from_yaml(self, cfg):
         # Create a YOLOv5 classification model from a *.yaml file
         self.model = None
+
+    def load(self, weights):
+        ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak
+        csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
+        csd = intersect_state_dicts(csd, self.state_dict())  # intersect
+        self.load_state_dict(csd, strict=False)  # load
diff --git a/ultralytics/yolo/utils/torch_utils.py b/ultralytics/yolo/utils/torch_utils.py
index 469c23da..0466810e 100644
--- a/ultralytics/yolo/utils/torch_utils.py
+++ b/ultralytics/yolo/utils/torch_utils.py
@@ -174,3 +174,8 @@ def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
         return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)
 
     return decorate
+
+
+def intersect_state_dicts(da, db, exclude=()):
+    # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
+    return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
diff --git a/ultralytics/yolo/v8/classify/__init__.py b/ultralytics/yolo/v8/classify/__init__.py
index 278a980f..23a43a31 100644
--- a/ultralytics/yolo/v8/classify/__init__.py
+++ b/ultralytics/yolo/v8/classify/__init__.py
@@ -1,3 +1,4 @@
-from ultralytics.yolo.v8.classify import train
+from ultralytics.yolo.v8.classify.train import ClassificationTrainer
+from ultralytics.yolo.v8.classify.val import ClassificationValidator
 
 __all__ = ["train"]
diff --git a/ultralytics/yolo/v8/classify/train.py b/ultralytics/yolo/v8/classify/train.py
index feb05e2b..534dbfd3 100644
--- a/ultralytics/yolo/v8/classify/train.py
+++ b/ultralytics/yolo/v8/classify/train.py
@@ -5,11 +5,10 @@ from pathlib import Path
 import hydra
 import torch
 import torchvision
-from val import ClassificationValidator
 
-from ultralytics.yolo import BaseTrainer, v8
+from ultralytics.yolo import v8
 from ultralytics.yolo.data import build_classification_dataloader
-from ultralytics.yolo.engine.trainer import CONFIG_PATH_ABS, DEFAULT_CONFIG
+from ultralytics.yolo.engine.trainer import CONFIG_PATH_ABS, DEFAULT_CONFIG, BaseTrainer
 from ultralytics.yolo.utils.downloads import download
 from ultralytics.yolo.utils.files import WorkingDirectory
 from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, torch_distributed_zero_first
@@ -18,9 +17,9 @@ from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, torch_distributed_zer
 # BaseTrainer python usage
 class ClassificationTrainer(BaseTrainer):
 
-    def get_dataset(self):
+    def get_dataset(self, dataset):
         # temporary solution. Replace with new ultralytics.yolo.ClassificationDataset module
-        data = Path("datasets") / self.data
+        data = Path("datasets") / dataset
         with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(Path.cwd()):
             data_dir = data if data.is_dir() else (Path.cwd() / data)
             if not data_dir.is_dir():
@@ -29,7 +28,7 @@ class ClassificationTrainer(BaseTrainer):
                 if str(data) == 'imagenet':
                     subprocess.run(f"bash {v8.ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
                 else:
-                    url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{self.data}.zip'
+                    url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
                     download(url, dir=data_dir.parent)
                 # TODO: add colorstr
                 s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {'bold', data_dir}\n"
@@ -39,17 +38,18 @@ class ClassificationTrainer(BaseTrainer):
 
         return train_set, test_set
 
-    def get_dataloader(self, dataset, batch_size=None, rank=0):
-        return build_classification_dataloader(path=dataset, batch_size=self.train.batch_size, rank=rank)
+    def get_dataloader(self, dataset_path, batch_size=None, rank=0):
+        return build_classification_dataloader(path=dataset_path, batch_size=self.args.batch_size, rank=rank)
 
-    def get_model(self):
+    def get_model(self, model, pretrained):
         # temp. minimal. only supports torchvision models
-        if self.model in torchvision.models.__dict__:  # TorchVision models i.e. resnet50, efficientnet_b0
-            model = torchvision.models.__dict__[self.model](weights='IMAGENET1K_V1' if self.train.pretrained else None)
+        model = self.args.model
+        if model in torchvision.models.__dict__:  # TorchVision models i.e. resnet50, efficientnet_b0
+            model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
         else:
-            raise ModuleNotFoundError(f'--model {self.model} not found.')
+            raise ModuleNotFoundError(f'--model {model} not found.')
         for m in model.modules():
-            if not self.train.pretrained and hasattr(m, 'reset_parameters'):
+            if not pretrained and hasattr(m, 'reset_parameters'):
                 m.reset_parameters()
         for p in model.parameters():
             p.requires_grad = True  # for training
@@ -57,7 +57,7 @@ class ClassificationTrainer(BaseTrainer):
         return model
 
     def get_validator(self):
-        return ClassificationValidator(self.test_loader, self.device, logger=self.console)  # validator
+        return v8.classify.ClassificationValidator(self.test_loader, self.device, logger=self.console)
 
     def criterion(self, preds, targets):
         return torch.nn.functional.cross_entropy(preds, targets)
@@ -66,17 +66,17 @@ class ClassificationTrainer(BaseTrainer):
 @hydra.main(version_base=None, config_path=CONFIG_PATH_ABS, config_name=str(DEFAULT_CONFIG).split(".")[0])
 def train(cfg):
     cfg.model = cfg.model or "squeezenet1_0"
-    cfg.data = cfg.data or "imagenette160"  # or yolo.ClassificationDataset("mnist")
+    cfg.data = cfg.data or "imagenette"  # or yolo.ClassificationDataset("mnist")
     trainer = ClassificationTrainer(cfg)
-    trainer.run()
+    trainer.train()
 
 
 if __name__ == "__main__":
     """
     CLI usage:
-    python ../path/to/train.py train.epochs=10 train.project="name" hyps.lr0=0.1
+    python ../path/to/train.py args.epochs=10 args.project="name" hyps.lr0=0.1
 
     TODO:
-    Direct cli support, i.e, yolov8 classify_train train.epochs 10
+    Direct cli support, i.e, yolov8 classify_train args.epochs 10
     """
     train()
diff --git a/ultralytics/yolo/v8/classify/val.py b/ultralytics/yolo/v8/classify/val.py
index 4657ffc9..3d3b4e90 100644
--- a/ultralytics/yolo/v8/classify/val.py
+++ b/ultralytics/yolo/v8/classify/val.py
@@ -1,9 +1,9 @@
 import torch
 
-from ultralytics import yolo
+from ultralytics.yolo.engine.validator import BaseValidator
 
 
-class ClassificationValidator(yolo.BaseValidator):
+class ClassificationValidator(BaseValidator):
 
     def init_metrics(self):
         self.correct = torch.tensor([])