diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 581a7d0a..ae23219e 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -94,7 +94,7 @@ jobs:
       - name: Test segmentation
         shell: bash  # for Windows compatibility
         run: |
-          python ultralytics/yolo/v8/segment/train.py cfg=yolov5n-seg.yaml data=coco128-seg.yaml epochs=1 img_size=64
+          python ultralytics/yolo/v8/segment/train.py model=yolov5n-seg.yaml data=coco128-seg.yaml epochs=1 img_size=64
       - name: Test classification
         shell: bash  # for Windows compatibility
         run: |
diff --git a/.gitignore b/.gitignore
index 75b43690..ed061573 100644
--- a/.gitignore
+++ b/.gitignore
@@ -130,4 +130,5 @@ dmypy.json
 
 # datasets and projects
 datasets/
-ultralytics-yolo/
\ No newline at end of file
+ultralytics-yolo/
+runs/
\ No newline at end of file
diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py
index 758f2277..c48da2b3 100644
--- a/ultralytics/yolo/engine/trainer.py
+++ b/ultralytics/yolo/engine/trainer.py
@@ -63,10 +63,8 @@ class BaseTrainer:
         else:
             self.data = check_dataset(self.data)
         self.trainset, self.testset = self.get_dataset(self.data)
-        if self.args.cfg is not None:
-            self.model = self.load_cfg(check_file(self.args.cfg))
-        if self.args.model is not None:
-            self.model = self.get_model(self.args.model, self.args.pretrained).to(self.device)
+        if self.args.model:
+            self.model = self.get_model(self.args.model, self.data)
 
         # epoch level metrics
         self.metrics = {}  # handle metrics returned by validator
@@ -261,20 +259,20 @@ class BaseTrainer:
         """
         return data["train"], data["val"]
 
-    def get_model(self, model, pretrained):
+    def get_model(self, model: str, data: Dict):
         """
         load/create/download model for any task
         """
-        model = get_model(model)
-        for m in model.modules():
-            if not pretrained and hasattr(m, 'reset_parameters'):
-                m.reset_parameters()
-        for p in model.parameters():
-            p.requires_grad = True
-
+        pretrained = False
+        if not str(model).endswith(".yaml"):
+            pretrained = True
+            weights = get_model(model)  # rename this to something less confusing?
+        model = self.load_model(model_cfg=model if not pretrained else None,
+                                weights=weights if pretrained else None,
+                                data=self.data)
         return model
 
-    def load_cfg(self, cfg):
+    def load_model(self, model_cfg, weights, data):
         raise NotImplementedError("This task trainer doesn't support loading cfg files")
 
     def get_validator(self):
diff --git a/ultralytics/yolo/utils/configs/default.yaml b/ultralytics/yolo/utils/configs/default.yaml
index b85c63d3..a1887bdb 100644
--- a/ultralytics/yolo/utils/configs/default.yaml
+++ b/ultralytics/yolo/utils/configs/default.yaml
@@ -3,8 +3,7 @@
 
 
 # Train settings -------------------------------------------------------------------------------------------------------
-model: null  # i.e. yolov5s.pt
-cfg: null  # i.e. yolov5s.yaml
+model: null  # i.e. yolov5s.pt, yolo.yaml
 data: null  # i.e. coco128.yaml
 epochs: 300
 batch_size: 16
@@ -70,6 +69,7 @@ mosaic: 1.0  # image mosaic (probability)
 mixup: 0.0  # image mixup (probability)
 copy_paste: 0.0  # segment copy-paste (probability)
 label_smoothing: 0.0
+# anchors: 3
 
 # Hydra configs --------------------------------------------------------------------------------------------------------
 hydra:
diff --git a/ultralytics/yolo/utils/downloads.py b/ultralytics/yolo/utils/downloads.py
index 1b09a3ee..71fa63de 100644
--- a/ultralytics/yolo/utils/downloads.py
+++ b/ultralytics/yolo/utils/downloads.py
@@ -140,8 +140,3 @@ def download(url, dir=Path.cwd(), unzip=True, delete=True, curl=False, threads=1
     else:
         for u in [url] if isinstance(url, (str, Path)) else url:
             download_one(u, dir)
-
-
-def get_model(model: str):
-    # check for local weights
-    pass
diff --git a/ultralytics/yolo/utils/modeling/tasks.py b/ultralytics/yolo/utils/modeling/tasks.py
index 0cbeb45b..c6c82b54 100644
--- a/ultralytics/yolo/utils/modeling/tasks.py
+++ b/ultralytics/yolo/utils/modeling/tasks.py
@@ -66,7 +66,7 @@ class BaseModel(nn.Module):
         return self
 
     def load(self, weights):
-        # Force all tasks implement this function
+        # Force all tasks to implement this function
         raise NotImplementedError("This function needs to be implemented by derived classes!")
 
 
@@ -169,10 +169,10 @@ class DetectionModel(BaseModel):
             mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
 
     def load(self, weights):
-        ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak
-        csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
+        csd = weights['model'].float().state_dict()  # checkpoint state_dict as FP32
         csd = intersect_state_dicts(csd, self.state_dict())  # intersect
         self.load_state_dict(csd, strict=False)  # load
+        LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from {weights}')
 
 
 class SegmentationModel(DetectionModel):
@@ -203,11 +203,33 @@ class ClassificationModel(BaseModel):
         self.nc = nc
 
     def _from_yaml(self, cfg):
-        # Create a YOLOv5 classification model from a *.yaml file
+        # TODO: Create a YOLOv5 classification model from a *.yaml file
         self.model = None
 
     def load(self, weights):
-        ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak
-        csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
+        model = weights["model"] if isinstance(weights, dict) else weights  # torchvision models are not dicts
+        csd = model.float().state_dict()
         csd = intersect_state_dicts(csd, self.state_dict())  # intersect
         self.load_state_dict(csd, strict=False)  # load
+
+    @staticmethod
+    def reshape_outputs(model, nc):
+        # Update a TorchVision classification model to class count 'n' if required
+        from ultralytics.yolo.utils.modeling.modules import Classify
+        name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1]  # last module
+        if isinstance(m, Classify):  # YOLO Classify() head
+            if m.linear.out_features != nc:
+                m.linear = nn.Linear(m.linear.in_features, nc)
+        elif isinstance(m, nn.Linear):  # ResNet, EfficientNet
+            if m.out_features != nc:
+                setattr(model, name, nn.Linear(m.in_features, nc))
+        elif isinstance(m, nn.Sequential):
+            types = [type(x) for x in m]
+            if nn.Linear in types:
+                i = types.index(nn.Linear)  # nn.Linear index
+                if m[i].out_features != nc:
+                    m[i] = nn.Linear(m[i].in_features, nc)
+            elif nn.Conv2d in types:
+                i = types.index(nn.Conv2d)  # nn.Conv2d index
+                if m[i].out_channels != nc:
+                    m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias)
diff --git a/ultralytics/yolo/v8/classify/train.py b/ultralytics/yolo/v8/classify/train.py
index 4027d878..4037b833 100644
--- a/ultralytics/yolo/v8/classify/train.py
+++ b/ultralytics/yolo/v8/classify/train.py
@@ -1,26 +1,27 @@
-import subprocess
-import time
-from pathlib import Path
-
 import hydra
 import torch
 
 from ultralytics.yolo import v8
 from ultralytics.yolo.data import build_classification_dataloader
 from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
-from ultralytics.yolo.utils import colorstr
-from ultralytics.yolo.utils.downloads import download
-from ultralytics.yolo.utils.files import WorkingDirectory
-from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, torch_distributed_zero_first
+from ultralytics.yolo.utils.modeling.tasks import ClassificationModel
 
 
-# BaseTrainer python usage
 class ClassificationTrainer(BaseTrainer):
 
+    def load_model(self, model_cfg, weights, data):
+        # TODO: why treat clf models as unique. We should have clf yamls?
+        if weights and not weights.__class__.__name__.startswith("yolo"):  # torchvision
+            model = weights
+        else:
+            model = ClassificationModel(model_cfg, weights, data["nc"])
+        ClassificationModel.reshape_outputs(model, data["nc"])
+        return model
+
     def get_dataloader(self, dataset_path, batch_size=None, rank=0):
         return build_classification_dataloader(path=dataset_path,
                                                imgsz=self.args.img_size,
-                                               batch_size=self.args.batch_size,
+                                               batch_size=batch_size,
                                                rank=rank)
 
     def preprocess_batch(self, batch):
diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py
index 548bae64..1dd64c87 100644
--- a/ultralytics/yolo/v8/segment/train.py
+++ b/ultralytics/yolo/v8/segment/train.py
@@ -10,12 +10,11 @@ import torch.nn.functional as F
 from ultralytics.yolo import v8
 from ultralytics.yolo.data import build_dataloader
 from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
-from ultralytics.yolo.utils.downloads import download
-from ultralytics.yolo.utils.files import WorkingDirectory
+from ultralytics.yolo.utils.anchors import check_anchors
 from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE
 from ultralytics.yolo.utils.modeling.tasks import SegmentationModel
 from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy
-from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, de_parallel, torch_distributed_zero_first
+from ultralytics.yolo.utils.torch_utils import de_parallel
 
 
 # BaseTrainer python usage
@@ -45,8 +44,15 @@ class SegmentationTrainer(BaseTrainer):
         batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
         return batch
 
-    def load_cfg(self, cfg):
-        return SegmentationModel(cfg, nc=80)
+    def load_model(self, model_cfg, weights, data):
+        model = SegmentationModel(model_cfg if model_cfg else weights["model"].yaml,
+                                  ch=3,
+                                  nc=data["nc"],
+                                  anchors=self.args.get("anchors"))
+        check_anchors(model, self.args.anchor_t, self.args.img_size)
+        if weights:
+            model.load(weights)
+        return model
 
     def get_validator(self):
         return v8.segment.SegmentationValidator(self.test_loader, self.device, logger=self.console)
@@ -232,7 +238,7 @@ class SegmentationTrainer(BaseTrainer):
 
 @hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
 def train(cfg):
-    cfg.cfg = v8.ROOT / "models/yolov5n-seg.yaml"
+    cfg.model = v8.ROOT / "models/yolov5n-seg.yaml"
     cfg.data = cfg.data or "coco128-seg.yaml"  # or yolo.ClassificationDataset("mnist")
     trainer = SegmentationTrainer(cfg)
     trainer.train()