diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py
index b007de51..97959575 100644
--- a/ultralytics/yolo/engine/trainer.py
+++ b/ultralytics/yolo/engine/trainer.py
@@ -9,6 +9,7 @@ Simple training loop; Boilerplate that could apply to any arbitrary neural netwo
 import os
 import time
 from collections import defaultdict
+from copy import deepcopy
 from datetime import datetime
 from pathlib import Path
 from typing import Dict, Union
@@ -29,6 +30,7 @@ from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT
 from ultralytics.yolo.utils.checks import print_args
 from ultralytics.yolo.utils.files import increment_path, save_yaml
 from ultralytics.yolo.utils.modeling import get_model
+from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel
 
 DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
 
@@ -63,6 +65,7 @@ class BaseTrainer:
         self.trainset, self.testset = self.get_dataset(self.data)
         if self.args.model:
             self.model = self.get_model(self.args.model)
+        self.ema = None
 
         # epoch level metrics
         self.metrics = {}  # handle metrics returned by validator
@@ -144,6 +147,7 @@ class BaseTrainer:
             self.validator = self.get_validator()
             print("created testloader :", rank)
             self.console.info(self.progress_string())
+            self.ema = ModelEMA(self.model)
 
     def _do_train(self, rank=-1, world_size=1):
         if world_size > 1:
@@ -196,6 +200,7 @@ class BaseTrainer:
             if rank in [-1, 0]:
                 # validation
                 # callback: on_val_start()
+                self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
                 self.validate()
                 # callback: on_val_end()
 
@@ -220,10 +225,10 @@ class BaseTrainer:
         ckpt = {
             'epoch': self.epoch,
             'best_fitness': self.best_fitness,
-            'model': None,  # deepcopy(ema.ema).half(),  # deepcopy(de_parallel(model)).half(),
-            'ema': None,  # deepcopy(ema.ema).half(),
-            'updates': None,  # ema.updates,
-            'optimizer': None,  # optimizer.state_dict(),
+            'model': deepcopy(de_parallel(self.model)).half(),
+            'ema': deepcopy(self.ema.ema).half(),
+            'updates': self.ema.updates,
+            'optimizer': self.optimizer.state_dict(),
             'train_args': self.args,
             'date': datetime.now().isoformat()}
 
@@ -266,6 +271,8 @@ class BaseTrainer:
         self.scaler.step(self.optimizer)
         self.scaler.update()
         self.optimizer.zero_grad()
+        if self.ema:
+            self.ema.update(self.model)
 
     def preprocess_batch(self, batch):
         """
diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py
index eeda7bf1..e60f0863 100644
--- a/ultralytics/yolo/engine/validator.py
+++ b/ultralytics/yolo/engine/validator.py
@@ -30,19 +30,16 @@ class BaseValidator:
         Supports validation of a pre-trained model if passed or a model being trained
         if trainer is passed (trainer gets priority).
         """
-        training = trainer is not None
-        self.training = training
-        # trainer = trainer or self.trainer_class.get_trainer()
-        assert training or model is not None, "Either trainer or model is needed for validation"
-        if training:
-            model = trainer.model
+        self.training = trainer is not None
+        if self.training:
+            model = trainer.ema.ema if trainer.ema else trainer.model
             self.args.half &= self.device.type != 'cpu'
-            # NOTE: half() inference in evaluation will make training stuck,
-            # so I comment it out for now, I think we can reuse half mode after we add EMA.
-            # model = model.half() if self.args.half else model
+            # NOTE(review): half()/float() cast the module in place, so the trainer's EMA weights change dtype here;
+            # confirm they are cast back to float before the next ModelEMA.update() - TODO verify later in this method.
+            model = model.half() if self.args.half else model.float()
         else:  # TODO: handle this when detectMultiBackend is supported
+            assert model is not None, "Either trainer or model is needed for validation"
             # model = DetectMultiBacked(model)
-            pass
             # TODO: implement init_model_attributes()
 
         model.eval()
@@ -50,7 +47,7 @@ class BaseValidator:
         loss = 0
         n_batches = len(self.dataloader)
         desc = self.get_desc()
-        bar = tqdm(self.dataloader, desc, n_batches, not training, bar_format=TQDM_BAR_FORMAT)
+        bar = tqdm(self.dataloader, desc, n_batches, not self.training, bar_format=TQDM_BAR_FORMAT)
         self.init_metrics(de_parallel(model))
         with torch.no_grad():
             for batch_i, batch in enumerate(bar):
@@ -67,7 +64,7 @@ class BaseValidator:
 
                 # loss
                 with dt[2]:
-                    if training:
+                    if self.training:
                         loss += trainer.criterion(preds, batch)[0]
 
                 # pre-process predictions
@@ -82,7 +79,7 @@ class BaseValidator:
         self.print_results()
 
         # print speeds
-        if not training:
+        if not self.training:
             t = tuple(x.t / len(self.dataloader.dataset.samples) * 1E3 for x in dt)  # speeds per image
             # shape = (self.dataloader.batch_size, 3, imgsz, imgsz)
             self.logger.info(
diff --git a/ultralytics/yolo/utils/modeling/tasks.py b/ultralytics/yolo/utils/modeling/tasks.py
index c6c82b54..70ef3542 100644
--- a/ultralytics/yolo/utils/modeling/tasks.py
+++ b/ultralytics/yolo/utils/modeling/tasks.py
@@ -232,4 +232,4 @@ class ClassificationModel(BaseModel):
             elif nn.Conv2d in types:
                 i = types.index(nn.Conv2d)  # nn.Conv2d index
                 if m[i].out_channels != nc:
-                    m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias)
+                    m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
diff --git a/ultralytics/yolo/utils/torch_utils.py b/ultralytics/yolo/utils/torch_utils.py
index 795e572c..3dec5810 100644
--- a/ultralytics/yolo/utils/torch_utils.py
+++ b/ultralytics/yolo/utils/torch_utils.py
@@ -192,3 +192,34 @@ def is_parallel(model):
 def de_parallel(model):
     # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
     return model.module if is_parallel(model) else model
+
+
+class ModelEMA:
+    """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
+    Holds `self.ema`, an eval-mode deep copy of the model whose floating-point state_dict entries track a moving
+    average of the live model's; see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+    """
+
+    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
+        # Create EMA: `decay` is the ramp's upper limit, `tau` its time constant, `updates` the initial update count
+        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
+        self.updates = updates  # number of EMA updates
+        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
+        for p in self.ema.parameters():
+            p.requires_grad_(False)
+
+    def update(self, model):
+        # Update EMA parameters: ema = d * ema + (1 - d) * model, with d read from the ramp at the new update count
+        self.updates += 1
+        d = self.decay(self.updates)
+
+        msd = de_parallel(model).state_dict()  # model state_dict
+        for k, v in self.ema.state_dict().items():
+            if v.dtype.is_floating_point:  # true for FP16 and FP32
+                v *= d
+                v += (1 - d) * msd[k].detach()
+        # Entries whose dtype is not floating point are skipped by the check above and keep their current EMA value
+
+    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
+        # Update EMA attributes by delegating to copy_attr(self.ema, model, include, exclude)
+        copy_attr(self.ema, model, include, exclude)
diff --git a/ultralytics/yolo/v8/segment/train.py b/ultralytics/yolo/v8/segment/train.py
index 8372b1b3..9949629f 100644
--- a/ultralytics/yolo/v8/segment/train.py
+++ b/ultralytics/yolo/v8/segment/train.py
@@ -159,11 +159,11 @@ class SegmentationTrainer(BaseTrainer):
 
             return tcls, tbox, indices, anch, tidxs, xywhn
 
-        if self.model.training:
+        if len(preds) == 2:  # eval
             p, proto, = preds
-        else:
-            p, proto, train_out = preds
-            p = train_out
+        else:  # len(3) train
+            _, proto, p = preds
+
         targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
         masks = batch["masks"]
         targets, masks = targets.to(self.device), masks.to(self.device).float()
diff --git a/ultralytics/yolo/v8/segment/val.py b/ultralytics/yolo/v8/segment/val.py
index f4a526f7..372f3067 100644
--- a/ultralytics/yolo/v8/segment/val.py
+++ b/ultralytics/yolo/v8/segment/val.py
@@ -1,5 +1,4 @@
 import os
-from pathlib import Path
 
 import numpy as np
 import torch