diff --git a/docs/hub.md b/docs/hub.md
index f7a7dbb2..1ae00b1c 100644
--- a/docs/hub.md
+++ b/docs/hub.md
@@ -1,30 +1,54 @@
 # Ultralytics HUB
 
-<div align="center">
-  <a href="https://hub.ultralytics.com" target="_blank">
-    <img width="1024" src="https://github.com/ultralytics/assets/raw/main/im/ultralytics-hub.png"></a>
+<a href="https://bit.ly/ultralytics_hub" target="_blank">
+<img width="100%" src="https://github.com/ultralytics/assets/raw/main/im/ultralytics-hub.png"></a>
 <br>
+<br>
+<div align="center">
+  <a href="https://github.com/ultralytics" style="text-decoration:none;">
+    <img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-github.png" width="2%" alt="" /></a>
+  <img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="2%" alt="" />
+  <a href="https://www.linkedin.com/company/ultralytics" style="text-decoration:none;">
+    <img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-linkedin.png" width="2%" alt="" /></a>
+  <img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="2%" alt="" />
+  <a href="https://twitter.com/ultralytics" style="text-decoration:none;">
+    <img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-twitter.png" width="2%" alt="" /></a>
+  <img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="2%" alt="" />
+  <a href="https://www.producthunt.com/@glenn_jocher" style="text-decoration:none;">
+    <img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-producthunt.png" width="2%" alt="" /></a>
+  <img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="2%" alt="" />
+  <a href="https://youtube.com/ultralytics" style="text-decoration:none;">
+    <img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-youtube.png" width="2%" alt="" /></a>
+  <img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="2%" alt="" />
+  <a href="https://www.facebook.com/ultralytics" style="text-decoration:none;">
+    <img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-facebook.png" width="2%" alt="" /></a>
+  <img src="https://github.com/ultralytics/assets/raw/main/social/logo-transparent.png" width="2%" alt="" />
+  <a href="https://www.instagram.com/ultralytics/" style="text-decoration:none;">
+    <img src="https://github.com/ultralytics/assets/raw/main/social/logo-social-instagram.png" width="2%" alt="" /></a>
+  <br>
+  <br>
   <a href="https://github.com/ultralytics/hub/actions/workflows/ci.yaml">
     <img src="https://github.com/ultralytics/hub/actions/workflows/ci.yaml/badge.svg" alt="CI CPU"></a>
+  <a href="https://colab.research.google.com/github/ultralytics/hub/blob/master/hub.ipynb">
+    <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a> 
 </div>
-
+<br>
 
 
 [Ultralytics HUB](https://hub.ultralytics.com) is a new no-code online tool developed
 by [Ultralytics](https://ultralytics.com), the creators of the popular [YOLOv5](https://github.com/ultralytics/yolov5)
-object detection and image segmentation models. With Ultralytics HUB, users can easily train and deploy YOLOv5 models
+object detection and image segmentation models. With Ultralytics HUB, users can easily train and deploy YOLO models
 without any coding or technical expertise.
 
 Ultralytics HUB is designed to be user-friendly and intuitive, with a drag-and-drop interface that allows users to
 easily upload their data and select their model configurations. It also offers a range of pre-trained models and
 templates to choose from, making it easy for users to get started with training their own models. Once a model is
 trained, it can be easily deployed and used for real-time object detection and image segmentation tasks. Overall,
-Ultralytics HUB is an essential tool for anyone looking to use YOLOv5 for their object detection and image segmentation
+Ultralytics HUB is an essential tool for anyone looking to use YOLO for their object detection and image segmentation
 projects.
 
 **[Get started now](https://hub.ultralytics.com)** and experience the power and simplicity of Ultralytics HUB for
-yourself. Sign up for a free account and
-start building, training, and deploying YOLOv5 and YOLOv8 models today.
+yourself. Sign up for a free account and start building, training, and deploying YOLOv5 and YOLOv8 models today.
 
 ## 1. Upload a Dataset
 
@@ -44,7 +68,9 @@ zip -r coco6.zip coco6
 The example [coco6.zip](https://github.com/ultralytics/hub/blob/master/coco6.zip) dataset in this repository can be
 downloaded and unzipped to see exactly how to structure your custom dataset.
 
-<p align="center"><img width="80%" src="https://user-images.githubusercontent.com/26833433/201424843-20fa081b-ad4b-4d6c-a095-e810775908d8.png" title="COCO6" /></p>
+<p align="center">
+<img width="80%" src="https://user-images.githubusercontent.com/26833433/201424843-20fa081b-ad4b-4d6c-a095-e810775908d8.png" title="COCO6" />
+</p>
 
 The dataset YAML is the same standard YOLOv5 YAML format. See
 the [YOLOv5 Train Custom Data tutorial](https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data) for full details.
@@ -68,20 +94,21 @@ names:
 After zipping your dataset, sign in to [Ultralytics HUB](https://bit.ly/ultralytics_hub) and click the Datasets tab.
 Click 'Upload Dataset' to upload, scan and visualize your new dataset before training new YOLOv5 models on it!
 
-<img width="100%" alt="HUB Dataset Upload" src="https://user-images.githubusercontent.com/26833433/198611715-540c9856-49d7-4069-a2fd-7c9eb70e772e.png">
+<img width="100%" alt="HUB Dataset Upload" src="https://user-images.githubusercontent.com/26833433/216763338-9a8812c8-a4e5-4362-8102-40dad7818396.png">
 
 ## 2. Train a Model
 
-Connect to the Ultralytics HUB notebook and use your model API key to begin
-training! <a href="https://colab.research.google.com/github/ultralytics/hub/blob/master/hub.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
+Connect to the Ultralytics HUB notebook and use your model API key to begin training! 
+
+<a href="https://colab.research.google.com/github/ultralytics/hub/blob/master/hub.ipynb" target="_blank">
+<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
 
 ## 3. Deploy to Real World
 
 Export your model to 13 different formats, including TensorFlow, ONNX, OpenVINO, CoreML, Paddle and many others. Run
-models directly on your mobile device by downloading the [Ultralytics App](https://ultralytics.com/app_install)!
-
-<a href="https://ultralytics.com/app_install" target="_blank">
-<img width="100%" alt="Ultralytics mobile app" src="https://github.com/ultralytics/assets/raw/main/im/ultralytics-app.png"></a>
+models directly on your [iOS](https://apps.apple.com/xk/app/ultralytics/id1583935240) or 
+[Android](https://play.google.com/store/apps/details?id=com.ultralytics.ultralytics_app) mobile device by downloading 
+the [Ultralytics App](https://ultralytics.com/app_install)!
 
 ## ❓ Issues
 
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index a5841b5f..6520f00d 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license
 
-__version__ = "8.0.31"
+__version__ = "8.0.32"
 
 from ultralytics.yolo.engine.model import YOLO
 from ultralytics.yolo.utils import ops
diff --git a/ultralytics/hub/session.py b/ultralytics/hub/session.py
index 883d06a9..f6e87855 100644
--- a/ultralytics/hub/session.py
+++ b/ultralytics/hub/session.py
@@ -12,7 +12,7 @@ from ultralytics.hub.utils import HUB_API_ROOT, check_dataset_disk_space, smart_
 from ultralytics.yolo.utils import is_colab, threaded, LOGGER, emojis, PREFIX
 from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
 
-AGENT_NAME = (f"python-{__version__}-colab" if is_colab() else f"python-{__version__}-local")
+AGENT_NAME = f"python-{__version__}-colab" if is_colab() else f"python-{__version__}-local"
 session = None
 
 
@@ -95,7 +95,8 @@ class HubTrainingSession:
 
             if data.get("status", None) == "trained":
                 raise ValueError(
-                    emojis(f"Model trained. View model at https://hub.ultralytics.com/models/{self.model_id} 🚀"))
+                    emojis(f"Model is already trained and uploaded to "
+                           f"https://hub.ultralytics.com/models/{self.model_id} 🚀"))
 
             if not data.get("data", None):
                 raise ValueError("Dataset may still be processing. Please wait a minute and try again.")  # RF fix
diff --git a/ultralytics/hub/utils.py b/ultralytics/hub/utils.py
index eec139f1..f2cff500 100644
--- a/ultralytics/hub/utils.py
+++ b/ultralytics/hub/utils.py
@@ -190,5 +190,4 @@ class Traces:
 
 
 # Run below code on hub/utils init -------------------------------------------------------------------------------------
-
 traces = Traces()
diff --git a/ultralytics/yolo/cfg/__init__.py b/ultralytics/yolo/cfg/__init__.py
index ec1885fa..33a2e4cc 100644
--- a/ultralytics/yolo/cfg/__init__.py
+++ b/ultralytics/yolo/cfg/__init__.py
@@ -49,19 +49,19 @@ CLI_HELP_MSG = \
     GitHub: https://github.com/ultralytics/ultralytics
     """
 
-CFG_FLOAT_KEYS = {'warmup_epochs', 'box', 'cls', 'dfl'}
+CFG_FLOAT_KEYS = {'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear'}
 CFG_FRACTION_KEYS = {
     'dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr', 'fl_gamma',
-    'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'degrees', 'translate', 'scale', 'shear', 'perspective', 'flipud',
-    'fliplr', 'mosaic', 'mixup', 'copy_paste', 'conf', 'iou'}
+    'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud', 'fliplr', 'mosaic',
+    'mixup', 'copy_paste', 'conf', 'iou'}
 CFG_INT_KEYS = {
     'epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
     'line_thickness', 'workspace', 'nbs'}
 CFG_BOOL_KEYS = {
-    'save', 'cache', 'exist_ok', 'pretrained', 'verbose', 'deterministic', 'single_cls', 'image_weights', 'rect',
-    'cos_lr', 'overlap_mask', 'val', 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt',
-    'save_conf', 'save_crop', 'hide_labels', 'hide_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks',
-    'boxes', 'keras', 'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'v5loader'}
+    'save', 'exist_ok', 'pretrained', 'verbose', 'deterministic', 'single_cls', 'image_weights', 'rect', 'cos_lr',
+    'overlap_mask', 'val', 'save_json', 'save_hybrid', 'half', 'dnn', 'plots', 'show', 'save_txt', 'save_conf',
+    'save_crop', 'hide_labels', 'hide_conf', 'visualize', 'augment', 'agnostic_nms', 'retina_masks', 'boxes', 'keras',
+    'optimize', 'int8', 'dynamic', 'simplify', 'nms', 'v5loader'}
 
 
 def cfg2dict(cfg):
diff --git a/ultralytics/yolo/data/base.py b/ultralytics/yolo/data/base.py
index 06347fa0..da321ccf 100644
--- a/ultralytics/yolo/data/base.py
+++ b/ultralytics/yolo/data/base.py
@@ -28,7 +28,6 @@ class BaseDataset(Dataset):
         self,
         img_path,
         imgsz=640,
-        label_path=None,
         cache=False,
         augment=True,
         hyp=None,
@@ -42,7 +41,6 @@ class BaseDataset(Dataset):
         super().__init__()
         self.img_path = img_path
         self.imgsz = imgsz
-        self.label_path = label_path
         self.augment = augment
         self.single_cls = single_cls
         self.prefix = prefix
diff --git a/ultralytics/yolo/data/build.py b/ultralytics/yolo/data/build.py
index 3448232e..4cd59832 100644
--- a/ultralytics/yolo/data/build.py
+++ b/ultralytics/yolo/data/build.py
@@ -61,7 +61,7 @@ def seed_worker(worker_id):
     random.seed(worker_seed)
 
 
-def build_dataloader(cfg, batch_size, img_path, stride=32, rect=False, label_path=None, rank=-1, mode="train"):
+def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, rank=-1, mode="train"):
     assert mode in ["train", "val"]
     shuffle = mode == "train"
     if cfg.rect and shuffle:
@@ -70,9 +70,8 @@ def build_dataloader(cfg, batch_size, img_path, stride=32, rect=False, label_pat
     with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
         dataset = YOLODataset(
             img_path=img_path,
-            label_path=label_path,
             imgsz=cfg.imgsz,
-            batch_size=batch_size,
+            batch_size=batch,
             augment=mode == "train",  # augmentation
             hyp=cfg,  # TODO: probably add a get_hyps_from_cfg function
             rect=cfg.rect or rect,  # rectangular batches
@@ -82,18 +81,19 @@ def build_dataloader(cfg, batch_size, img_path, stride=32, rect=False, label_pat
             pad=0.0 if mode == "train" else 0.5,
             prefix=colorstr(f"{mode}: "),
             use_segments=cfg.task == "segment",
-            use_keypoints=cfg.task == "keypoint")
+            use_keypoints=cfg.task == "keypoint",
+            names=names)
 
-    batch_size = min(batch_size, len(dataset))
+    batch = min(batch, len(dataset))
     nd = torch.cuda.device_count()  # number of CUDA devices
     workers = cfg.workers if mode == "train" else cfg.workers * 2
-    nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])  # number of workers
+    nw = min([os.cpu_count() // max(nd, 1), batch if batch > 1 else 0, workers])  # number of workers
     sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
     loader = DataLoader if cfg.image_weights or cfg.close_mosaic else InfiniteDataLoader  # allow attribute updates
     generator = torch.Generator()
     generator.manual_seed(6148914691236517205 + RANK)
     return loader(dataset=dataset,
-                  batch_size=batch_size,
+                  batch_size=batch,
                   shuffle=shuffle and sampler is None,
                   num_workers=nw,
                   sampler=sampler,
diff --git a/ultralytics/yolo/data/dataset.py b/ultralytics/yolo/data/dataset.py
index a6f52018..fc58f6c4 100644
--- a/ultralytics/yolo/data/dataset.py
+++ b/ultralytics/yolo/data/dataset.py
@@ -14,7 +14,7 @@ from .utils import HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image
 
 
 class YOLODataset(BaseDataset):
-    cache_version = 1.0  # dataset labels *.cache version, >= 1.0 for YOLOv8
+    cache_version = '1.0.1'  # dataset labels *.cache version, >= 1.0.0 for YOLOv8
     rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
     """YOLO Dataset.
     Args:
@@ -22,28 +22,26 @@ class YOLODataset(BaseDataset):
         prefix (str): prefix.
     """
 
-    def __init__(
-        self,
-        img_path,
-        imgsz=640,
-        label_path=None,
-        cache=False,
-        augment=True,
-        hyp=None,
-        prefix="",
-        rect=False,
-        batch_size=None,
-        stride=32,
-        pad=0.0,
-        single_cls=False,
-        use_segments=False,
-        use_keypoints=False,
-    ):
+    def __init__(self,
+                 img_path,
+                 imgsz=640,
+                 cache=False,
+                 augment=True,
+                 hyp=None,
+                 prefix="",
+                 rect=False,
+                 batch_size=None,
+                 stride=32,
+                 pad=0.0,
+                 single_cls=False,
+                 use_segments=False,
+                 use_keypoints=False,
+                 names=None):
         self.use_segments = use_segments
         self.use_keypoints = use_keypoints
+        self.names = names
         assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
-        super().__init__(img_path, imgsz, label_path, cache, augment, hyp, prefix, rect, batch_size, stride, pad,
-                         single_cls)
+        super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls)
 
     def cache_labels(self, path=Path("./labels.cache")):
         # Cache dataset labels, check images and read shapes
@@ -56,7 +54,7 @@ class YOLODataset(BaseDataset):
         with ThreadPool(NUM_THREADS) as pool:
             results = pool.imap(func=verify_image_label,
                                 iterable=zip(self.im_files, self.label_files, repeat(self.prefix),
-                                             repeat(self.use_keypoints)))
+                                             repeat(self.use_keypoints), repeat(len(self.names))))
             pbar = tqdm(results, desc=desc, total=total, bar_format=TQDM_BAR_FORMAT)
             for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                 nm += nm_f
diff --git a/ultralytics/yolo/data/utils.py b/ultralytics/yolo/data/utils.py
index 91d8fe02..e9ec668a 100644
--- a/ultralytics/yolo/data/utils.py
+++ b/ultralytics/yolo/data/utils.py
@@ -61,7 +61,7 @@ def exif_size(img):
 
 def verify_image_label(args):
     # Verify one image-label pair
-    im_file, lb_file, prefix, keypoint = args
+    im_file, lb_file, prefix, keypoint, num_cls = args
     # number (missing, found, empty, corrupt), message, segments, keypoints
     nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
     try:
@@ -97,16 +97,20 @@ def verify_image_label(args):
                     assert (lb[:, 6::3] <= 1).all(), "non-normalized or out of bounds coordinate labels"
                     kpts = np.zeros((lb.shape[0], 39))
                     for i in range(len(lb)):
-                        kpt = np.delete(lb[i, 5:], np.arange(2, lb.shape[1] - 5,
-                                                             3))  # remove the occlusion parameter from the GT
+                        kpt = np.delete(lb[i, 5:], np.arange(2, lb.shape[1] - 5, 3))  # remove occlusion param from GT
                         kpts[i] = np.hstack((lb[i, :5], kpt))
                     lb = kpts
                     assert lb.shape[1] == 39, "labels require 39 columns each after removing occlusion parameter"
                 else:
                     assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
-                    assert (lb >= 0).all(), f"negative label values {lb[lb < 0]}"
-                    assert (lb[:, 1:] <=
-                            1).all(), f"non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}"
+                    assert (lb[:, 1:] <= 1).all(), \
+                        f"non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}"
+                # All labels
+                max_cls = int(lb[:, 0].max())  # max label count
+                assert max_cls <= num_cls, \
+                    f'Label class {max_cls} exceeds dataset class count {num_cls}. ' \
+                    f'Possible class labels are 0-{num_cls - 1}'
+                assert (lb >= 0).all(), f"negative label values {lb[lb < 0]}"
                 _, i = np.unique(lb, axis=0, return_index=True)
                 if len(i) < nl:  # duplicate row check
                     lb = lb[i]  # remove duplicates
@@ -192,8 +196,8 @@ def check_det_dataset(dataset, autodownload=True):
     # Download (optional)
     extract_dir = ''
     if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)):
-        download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False, threads=1)
-        data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
+        new_dir = safe_download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False)
+        data = next((DATASETS_DIR / new_dir).rglob('*.yaml'))
         extract_dir, autodownload = data.parent, False
 
     # Read yaml (optional)
diff --git a/ultralytics/yolo/engine/exporter.py b/ultralytics/yolo/engine/exporter.py
index a70a68a3..442ae171 100644
--- a/ultralytics/yolo/engine/exporter.py
+++ b/ultralytics/yolo/engine/exporter.py
@@ -203,7 +203,7 @@ class Exporter:
         self.im = im
         self.model = model
         self.file = file
-        self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else (x.shape for x in y)
+        self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y)
         self.pretty_name = self.file.stem.replace('yolo', 'YOLO')
         self.metadata = {
             'description': f"Ultralytics {self.pretty_name} model trained on {self.model.args['data']}",
@@ -213,8 +213,8 @@ class Exporter:
             'stride': int(max(model.stride)),
             'names': model.names}  # model metadata
 
-        LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} and "
-                    f"output shape {self.output_shape} ({file_size(file):.1f} MB)")
+        LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} BCHW and "
+                    f"output shape(s) {self.output_shape} ({file_size(file):.1f} MB)")
 
         # Exports
         f = [''] * len(fmts)  # exported filenames
@@ -234,19 +234,22 @@ class Exporter:
             nms = False
             f[5], s_model = self._export_saved_model(nms=nms or self.args.agnostic_nms or tfjs,
                                                      agnostic_nms=self.args.agnostic_nms or tfjs)
-            if pb or tfjs:  # pb prerequisite to tfjs
-                f[6], _ = self._export_pb(s_model)
-            if tflite or edgetpu:
-                f[7], _ = self._export_tflite(s_model,
-                                              int8=self.args.int8 or edgetpu,
-                                              data=self.args.data,
-                                              nms=nms,
-                                              agnostic_nms=self.args.agnostic_nms)
-                if edgetpu:
-                    f[8], _ = self._export_edgetpu()
-                self._add_tflite_metadata(f[8] or f[7], num_outputs=len(self.output_shape))
-            if tfjs:
-                f[9], _ = self._export_tfjs()
+
+            debug = False
+            if debug:
+                if pb or tfjs:  # pb prerequisite to tfjs
+                    f[6], _ = self._export_pb(s_model)
+                if tflite or edgetpu:
+                    f[7], _ = self._export_tflite(s_model,
+                                                  int8=self.args.int8 or edgetpu,
+                                                  data=self.args.data,
+                                                  nms=nms,
+                                                  agnostic_nms=self.args.agnostic_nms)
+                    if edgetpu:
+                        f[8], _ = self._export_edgetpu()
+                    self._add_tflite_metadata(f[8] or f[7], num_outputs=len(self.output_shape))
+                if tfjs:
+                    f[9], _ = self._export_tfjs()
         if paddle:  # PaddlePaddle
             f[10], _ = self._export_paddle()
 
diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py
index b3e8f587..e57f3bb0 100644
--- a/ultralytics/yolo/engine/validator.py
+++ b/ultralytics/yolo/engine/validator.py
@@ -120,7 +120,7 @@ class BaseValidator:
             if not pt:
                 self.args.rect = False
             self.dataloader = self.dataloader or \
-                              self.get_dataloader(self.data.get("val") or self.data.set("test"), self.args.batch)
+                              self.get_dataloader(self.data.get("val") or self.data.get("test"), self.args.batch)
 
             model.eval()
             model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz))  # warmup
diff --git a/ultralytics/yolo/utils/downloads.py b/ultralytics/yolo/utils/downloads.py
index b7b6d746..1a5f49e1 100644
--- a/ultralytics/yolo/utils/downloads.py
+++ b/ultralytics/yolo/utils/downloads.py
@@ -39,6 +39,7 @@ def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')):
         for f in zipObj.namelist():  # list all archived filenames in the zip
             if all(x not in f for x in exclude):
                 zipObj.extract(f, path=path)
+        return zipObj.namelist()[0]  # return unzip dir
 
 
 def safe_download(url,
@@ -112,13 +113,14 @@ def safe_download(url,
         unzip_dir = dir or f.parent  # unzip to dir if provided else unzip in place
         LOGGER.info(f'Unzipping {f} to {unzip_dir}...')
         if f.suffix == '.zip':
-            unzip_file(file=f, path=unzip_dir)  # unzip
+            unzip_dir = unzip_file(file=f, path=unzip_dir)  # unzip
         elif f.suffix == '.tar':
             subprocess.run(['tar', 'xf', f, '--directory', unzip_dir], check=True)  # unzip
         elif f.suffix == '.gz':
             subprocess.run(['tar', 'xfz', f, '--directory', unzip_dir], check=True)  # unzip
         if delete:
             f.unlink()  # remove zip
+        return unzip_dir
 
 
 def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
diff --git a/ultralytics/yolo/v8/detect/train.py b/ultralytics/yolo/v8/detect/train.py
index c5b3a8b7..c199d228 100644
--- a/ultralytics/yolo/v8/detect/train.py
+++ b/ultralytics/yolo/v8/detect/train.py
@@ -41,7 +41,7 @@ class DetectionTrainer(BaseTrainer):
                                  shuffle=mode == "train",
                                  seed=self.args.seed)[0] if self.args.v5loader else \
             build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, rank=rank, mode=mode,
-                             rect=mode == "val")[0]
+                             rect=mode == "val", names=self.data['names'])[0]
 
     def preprocess_batch(self, batch):
         batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
diff --git a/ultralytics/yolo/v8/detect/val.py b/ultralytics/yolo/v8/detect/val.py
index bc2148dc..f093b229 100644
--- a/ultralytics/yolo/v8/detect/val.py
+++ b/ultralytics/yolo/v8/detect/val.py
@@ -176,7 +176,8 @@ class DetectionValidator(BaseValidator):
                                  prefix=colorstr(f'{self.args.mode}: '),
                                  shuffle=False,
                                  seed=self.args.seed)[0] if self.args.v5loader else \
-            build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, mode="val")[0]
+            build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, names=self.data['names'],
+                             mode="val")[0]
 
     def plot_val_samples(self, batch, ni):
         plot_images(batch["img"],