diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 78148249..9e1b2b08 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -118,9 +118,9 @@ jobs:
         run: |
           yolo checks
           pip list
-      #      - name: Benchmark DetectionModel
-      #        shell: bash
-      #        run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}.pt' imgsz=160 verbose=0.318
+      - name: Benchmark World DetectionModel
+        shell: bash
+        run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/yolov8s-worldv2.pt' imgsz=160 verbose=0.318
       - name: Benchmark SegmentationModel
         shell: bash
         run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}-seg.pt' imgsz=160 verbose=0.281
diff --git a/docs/en/models/yolo-world.md b/docs/en/models/yolo-world.md
index 954b5dd1..116d62df 100644
--- a/docs/en/models/yolo-world.md
+++ b/docs/en/models/yolo-world.md
@@ -36,21 +36,29 @@ This section details the models available with their specific pre-trained weight
 
     All the YOLOv8-World weights have been directly migrated from the official [YOLO-World](https://github.com/AILab-CVC/YOLO-World) repository, highlighting their excellent contributions.
 
-| Model Type    | Pre-trained Weights                                                                                 | Tasks Supported                        | Inference | Validation | Training | Export |
-|---------------|-----------------------------------------------------------------------------------------------------|----------------------------------------|-----------|------------|----------|--------|
-| YOLOv8s-world | [yolov8s-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8s-world.pt) | [Object Detection](../tasks/detect.md) | ✅         | ✅          | ❌        | ❌      |
-| YOLOv8m-world | [yolov8m-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8m-world.pt) | [Object Detection](../tasks/detect.md) | ✅         | ✅          | ❌        | ❌      |
-| YOLOv8l-world | [yolov8l-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8l-world.pt) | [Object Detection](../tasks/detect.md) | ✅         | ✅          | ❌        | ❌      |
-| YOLOv8x-world | [yolov8x-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8x-world.pt) | [Object Detection](../tasks/detect.md) | ✅         | ✅          | ❌        | ❌      |
+| Model Type      | Pre-trained Weights                                                                                   | Tasks Supported                        | Inference | Validation | Training | Export |
+|-----------------|-------------------------------------------------------------------------------------------------------|----------------------------------------|-----------|------------|----------|--------|
+| YOLOv8s-world   | [yolov8s-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8s-world.pt)   | [Object Detection](../tasks/detect.md) | ✅        | ✅         | ❌       | ❌     |
+| YOLOv8s-worldv2 | [yolov8s-worldv2.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8s-worldv2.pt) | [Object Detection](../tasks/detect.md) | ✅        | ✅         | ❌       | ✅     |
+| YOLOv8m-world   | [yolov8m-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8m-world.pt)   | [Object Detection](../tasks/detect.md) | ✅        | ✅         | ❌       | ❌     |
+| YOLOv8m-worldv2 | [yolov8m-worldv2.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8m-worldv2.pt) | [Object Detection](../tasks/detect.md) | ✅        | ✅         | ❌       | ✅     |
+| YOLOv8l-world   | [yolov8l-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8l-world.pt)   | [Object Detection](../tasks/detect.md) | ✅        | ✅         | ❌       | ❌     |
+| YOLOv8l-worldv2 | [yolov8l-worldv2.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8l-worldv2.pt) | [Object Detection](../tasks/detect.md) | ✅        | ✅         | ❌       | ✅     |
+| YOLOv8x-world   | [yolov8x-world.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8x-world.pt)   | [Object Detection](../tasks/detect.md) | ✅        | ✅         | ❌       | ❌     |
+| YOLOv8x-worldv2 | [yolov8x-worldv2.pt](https://github.com/ultralytics/assets/releases/download/v8.1.0/yolov8x-worldv2.pt) | [Object Detection](../tasks/detect.md) | ✅        | ✅         | ❌       | ✅     |
 
 ## Zero-shot Transfer on COCO Dataset
 
-| Model Type    | mAP  | mAP50 | mAP75 |
-|---------------|------|-------|-------|
-| yolov8s-world | 37.4 | 52.0  | 40.6  |
-| yolov8m-world | 42.0 | 57.0  | 45.6  |
-| yolov8l-world | 45.7 | 61.3  | 49.8  |
-| yolov8x-world | 47.0 | 63.0  | 51.2  |
+| Model Type      | mAP  | mAP50 | mAP75 |
+|-----------------|------|-------|-------|
+| yolov8s-world   | 37.4 | 52.0  | 40.6  |
+| yolov8s-worldv2 | 37.7 | 52.2  | 41.0  |
+| yolov8m-world   | 42.0 | 57.0  | 45.6  |
+| yolov8m-worldv2 | 43.0 | 58.4  | 46.8  |
+| yolov8l-world   | 45.7 | 61.3  | 49.8  |
+| yolov8l-worldv2 | 45.8 | 61.3  | 49.8  |
+| yolov8x-world   | 47.0 | 63.0  | 51.2  |
+| yolov8x-worldv2 | 47.1 | 62.8  | 51.4  |
 
 ## Usage Examples
 
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index 666aea71..81a263ef 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = "8.1.20"
+__version__ = "8.1.21"
 
 from ultralytics.data.explorer.explorer import Explorer
 from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
diff --git a/ultralytics/cfg/models/v8/yolov8-world.yaml b/ultralytics/cfg/models/v8/yolov8-world.yaml
index 611ea1a9..c21a7f00 100644
--- a/ultralytics/cfg/models/v8/yolov8-world.yaml
+++ b/ultralytics/cfg/models/v8/yolov8-world.yaml
@@ -1,5 +1,5 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
-# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
+# YOLOv8-World object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/tasks/detect
 
 # Parameters
 nc: 80 # number of classes
diff --git a/ultralytics/cfg/models/v8/yolov8-world-t2i.yaml b/ultralytics/cfg/models/v8/yolov8-worldv2.yaml
similarity index 83%
rename from ultralytics/cfg/models/v8/yolov8-world-t2i.yaml
rename to ultralytics/cfg/models/v8/yolov8-worldv2.yaml
index 6b654adb..322b97d4 100644
--- a/ultralytics/cfg/models/v8/yolov8-world-t2i.yaml
+++ b/ultralytics/cfg/models/v8/yolov8-worldv2.yaml
@@ -1,5 +1,5 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
-# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
+# YOLOv8-World-v2 object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/tasks/detect
 
 # Parameters
 nc: 80 # number of classes
@@ -29,18 +29,18 @@ backbone:
 head:
   - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
   - [[-1, 6], 1, Concat, [1]] # cat backbone P4
-  - [-1, 2, C2fAttn, [512, 256, 8]] # 12
+  - [-1, 3, C2fAttn, [512, 256, 8]] # 12
 
   - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
   - [[-1, 4], 1, Concat, [1]] # cat backbone P3
-  - [-1, 2, C2fAttn, [256, 128, 4]] # 15 (P3/8-small)
+  - [-1, 3, C2fAttn, [256, 128, 4]] # 15 (P3/8-small)
 
   - [15, 1, Conv, [256, 3, 2]]
   - [[-1, 12], 1, Concat, [1]] # cat head P4
-  - [-1, 2, C2fAttn, [512, 256, 8]] # 18 (P4/16-medium)
+  - [-1, 3, C2fAttn, [512, 256, 8]] # 18 (P4/16-medium)
 
   - [-1, 1, Conv, [512, 3, 2]]
   - [[-1, 9], 1, Concat, [1]] # cat head P5
-  - [-1, 2, C2fAttn, [1024, 512, 16]] # 21 (P5/32-large)
+  - [-1, 3, C2fAttn, [1024, 512, 16]] # 21 (P5/32-large)
 
   - [[15, 18, 21], 1, WorldDetect, [nc, 512, True]] # Detect(P3, P4, P5)
diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py
index 9dae5278..173983e2 100644
--- a/ultralytics/engine/exporter.py
+++ b/ultralytics/engine/exporter.py
@@ -68,7 +68,7 @@ from ultralytics.data.dataset import YOLODataset
 from ultralytics.data.utils import check_det_dataset
 from ultralytics.nn.autobackend import check_class_names, default_class_names
 from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
-from ultralytics.nn.tasks import DetectionModel, SegmentationModel
+from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel
 from ultralytics.utils import (
     ARM64,
     DEFAULT_CFG,
@@ -201,6 +201,14 @@ class Exporter:
             assert self.device.type == "cpu", "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
         if edgetpu and not LINUX:
             raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler/")
+        print(type(model))
+        if isinstance(model, WorldModel):
+            LOGGER.warning(
+                "WARNING ⚠️ YOLOWorld (original version) export is not supported to any format.\n"
+                "WARNING ⚠️ YOLOWorldv2 models (i.e. 'yolov8s-worldv2.pt') only support export to "
+                "(torchscript, onnx, openvino, engine, coreml) formats. "
+                "See https://docs.ultralytics.com/models/yolo-world for details."
+            )
 
         # Input
         im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
@@ -252,9 +260,10 @@ class Exporter:
         self.metadata = {
             "description": description,
             "author": "Ultralytics",
-            "license": "AGPL-3.0 https://ultralytics.com/license",
             "date": datetime.now().isoformat(),
             "version": __version__,
+            "license": "AGPL-3.0 License (https://ultralytics.com/license)",
+            "docs": "https://docs.ultralytics.com",
             "stride": int(max(model.stride)),
             "task": model.task,
             "batch": self.args.batch,
diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py
index 061dfe5f..dd000cf1 100644
--- a/ultralytics/engine/model.py
+++ b/ultralytics/engine/model.py
@@ -295,7 +295,7 @@ class Model(nn.Module):
         self.model.load(weights)
         return self
 
-    def save(self, filename: Union[str, Path] = "saved_model.pt") -> None:
+    def save(self, filename: Union[str, Path] = "saved_model.pt", use_dill=True) -> None:
         """
         Saves the current model state to a file.
 
@@ -303,12 +303,22 @@ class Model(nn.Module):
 
         Args:
             filename (str | Path): The name of the file to save the model to. Defaults to 'saved_model.pt'.
+            use_dill (bool): Whether to try using dill for serialization if available. Defaults to True.
 
         Raises:
             AssertionError: If the model is not a PyTorch model.
         """
         self._check_is_pytorch_model()
-        torch.save(self.ckpt, filename)
+        from ultralytics import __version__
+        from datetime import datetime
+
+        updates = {
+            "date": datetime.now().isoformat(),
+            "version": __version__,
+            "license": "AGPL-3.0 License (https://ultralytics.com/license)",
+            "docs": "https://docs.ultralytics.com",
+        }
+        torch.save({**self.ckpt, **updates}, filename, use_dill=use_dill)
 
     def info(self, detailed: bool = False, verbose: bool = True):
         """
diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py
index 33821171..f005f341 100644
--- a/ultralytics/engine/trainer.py
+++ b/ultralytics/engine/trainer.py
@@ -488,6 +488,8 @@ class BaseTrainer:
             "train_results": results,
             "date": datetime.now().isoformat(),
             "version": __version__,
+            "license": "AGPL-3.0 (https://ultralytics.com/license)",
+            "docs": "https://docs.ultralytics.com",
         }
 
         # Save last and best
diff --git a/ultralytics/models/yolo/model.py b/ultralytics/models/yolo/model.py
index 44b0d9e8..5a2dc24f 100644
--- a/ultralytics/models/yolo/model.py
+++ b/ultralytics/models/yolo/model.py
@@ -13,8 +13,8 @@ class YOLO(Model):
 
     def __init__(self, model="yolov8n.pt", task=None, verbose=False):
         """Initialize YOLO model, switching to YOLOWorld if model filename contains '-world'."""
-        stem = Path(model).stem  # filename stem without suffix, i.e. "yolov8n"
-        if "-world" in stem:
+        model = Path(model)
+        if "-world" in model.stem and model.suffix in {".pt", ".yaml", ".yml"}:  # if YOLOWorld PyTorch model
             new_instance = YOLOWorld(model)
             self.__class__ = type(new_instance)
             self.__dict__ = new_instance.__dict__
@@ -67,7 +67,7 @@ class YOLOWorld(Model):
         Initializes the YOLOv8-World model with the given pre-trained model file. Supports *.pt and *.yaml formats.
 
         Args:
-            model (str): Path to the pre-trained model. Defaults to 'yolov8s-world.pt'.
+            model (str | Path): Path to the pre-trained model. Defaults to 'yolov8s-world.pt'.
         """
         super().__init__(model=model, task="detect")
 
diff --git a/ultralytics/utils/benchmarks.py b/ultralytics/utils/benchmarks.py
index b98448da..3bc63510 100644
--- a/ultralytics/utils/benchmarks.py
+++ b/ultralytics/utils/benchmarks.py
@@ -32,7 +32,7 @@ from pathlib import Path
 import numpy as np
 import torch.cuda
 
-from ultralytics import YOLO
+from ultralytics import YOLO, YOLOWorld
 from ultralytics.cfg import TASK2DATA, TASK2METRIC
 from ultralytics.engine.exporter import export_formats
 from ultralytics.utils import ASSETS, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR
@@ -84,14 +84,20 @@ def benchmark(
         emoji, filename = "❌", None  # export defaults
         try:
             # Checks
-            if i == 9:
+            if i == 9:  # Edge TPU
                 assert LINUX, "Edge TPU export only supported on Linux"
-            elif i == 7:
+            elif i == 7:  # TF GraphDef
                 assert model.task != "obb", "TensorFlow GraphDef not supported for OBB task"
             elif i in {5, 10}:  # CoreML and TF.js
                 assert MACOS or LINUX, "export only supported on macOS and Linux"
             if i in {3, 5}:  # CoreML and OpenVINO
                 assert not IS_PYTHON_3_12, "CoreML and OpenVINO not supported on Python 3.12"
+            if i in {6, 7, 8, 9, 10}:  # All TF formats
+                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
+            if i in {11}:  # Paddle
+                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet"
+            if i in {12}:  # NCNN
+                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet"
             if "cpu" in device.type:
                 assert cpu, "inference not supported on CPU"
             if "cuda" in device.type:
@@ -261,7 +267,8 @@ class ProfileModels:
         """
         return 0.0, 0.0, 0.0, 0.0  # return (num_layers, num_params, num_gradients, num_flops)
 
-    def iterative_sigma_clipping(self, data, sigma=2, max_iters=3):
+    @staticmethod
+    def iterative_sigma_clipping(data, sigma=2, max_iters=3):
         """Applies an iterative sigma clipping algorithm to the given data times number of iterations."""
         data = np.array(data)
         for _ in range(max_iters):
@@ -359,9 +366,13 @@ class ProfileModels:
     def generate_table_row(self, model_name, t_onnx, t_engine, model_info):
         """Generates a formatted string for a table row that includes model performance and metric details."""
         layers, params, gradients, flops = model_info
-        return f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± {t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |"
+        return (
+            f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.2f} ± {t_onnx[1]:.2f} ms | {t_engine[0]:.2f} ± "
+            f"{t_engine[1]:.2f} ms | {params / 1e6:.1f} | {flops:.1f} |"
+        )
 
-    def generate_results_dict(self, model_name, t_onnx, t_engine, model_info):
+    @staticmethod
+    def generate_results_dict(model_name, t_onnx, t_engine, model_info):
         """Generates a dictionary of model details including name, parameters, GFLOPS and speed metrics."""
         layers, params, gradients, flops = model_info
         return {
@@ -372,11 +383,18 @@ class ProfileModels:
             "model/speed_TensorRT(ms)": round(t_engine[0], 3),
         }
 
-    def print_table(self, table_rows):
+    @staticmethod
+    def print_table(table_rows):
         """Formats and prints a comparison table for different models with given statistics and performance data."""
         gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "GPU"
-        header = f"| Model | size
(pixels) | mAPval
50-95 | Speed
CPU ONNX
(ms) | Speed
{gpu} TensorRT
(ms) | params
(M) | FLOPs
(B) |"
-        separator = "|-------------|---------------------|--------------------|------------------------------|-----------------------------------|------------------|-----------------|"
+        header = (
+            f"| Model | size
(pixels) | mAPval
50-95 | Speed
CPU ONNX
(ms) | "
+            f"Speed
{gpu} TensorRT
(ms) | params
(M) | FLOPs
(B) |"
+        )
+        separator = (
+            "|-------------|---------------------|--------------------|------------------------------|"
+            "-----------------------------------|------------------|-----------------|"
+        )
 
         print(f"\n\n{header}")
         print(separator)
diff --git a/ultralytics/utils/downloads.py b/ultralytics/utils/downloads.py
index 21314562..470fa83d 100644
--- a/ultralytics/utils/downloads.py
+++ b/ultralytics/utils/downloads.py
@@ -20,7 +20,8 @@ GITHUB_ASSETS_NAMES = (
     [f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")]
     + [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")]
     + [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")]
-    + [f"yolov8{k}-world.pt" for k in "sml"]
+    + [f"yolov8{k}-world.pt" for k in "smlx"]
+    + [f"yolov8{k}-worldv2.pt" for k in "smlx"]
     + [f"yolo_nas_{k}.pt" for k in "sml"]
     + [f"sam_{k}.pt" for k in "bl"]
     + [f"FastSAM-{k}.pt" for k in "sx"]
diff --git a/ultralytics/utils/patches.py b/ultralytics/utils/patches.py
index 703ec19d..acbf5a99 100644
--- a/ultralytics/utils/patches.py
+++ b/ultralytics/utils/patches.py
@@ -60,27 +60,29 @@ def imshow(winname: str, mat: np.ndarray):
 _torch_save = torch.save  # copy to avoid recursion errors
 
 
-def torch_save(*args, **kwargs):
+def torch_save(*args, use_dill=True, **kwargs):
     """
-    Use dill (if exists) to serialize the lambda functions where pickle does not do this. Also adds 3 retries with
-    exponential standoff in case of save failure to improve robustness to transient issues.
+    Optionally use dill to serialize lambda functions where pickle does not, adding robustness with 3 retries and
+    exponential standoff in case of save failure.
 
     Args:
         *args (tuple): Positional arguments to pass to torch.save.
+        use_dill (bool): Whether to try using dill for serialization if available. Defaults to True.
         **kwargs (dict): Keyword arguments to pass to torch.save.
     """
     try:
-        import dill as pickle  # noqa
-    except ImportError:
+        assert use_dill
+        import dill as pickle
+    except (AssertionError, ImportError):
         import pickle
 
     if "pickle_module" not in kwargs:
-        kwargs["pickle_module"] = pickle  # noqa
+        kwargs["pickle_module"] = pickle
 
     for i in range(4):  # 3 retries
         try:
             return _torch_save(*args, **kwargs)
-        except RuntimeError:  # unable to save, possibly waiting for device to flush or anti-virus to finish scanning
+        except RuntimeError as e:  # unable to save, possibly waiting for device to flush or antivirus scan
             if i == 3:
-                raise
-            time.sleep((2**i) / 2)  # exponential standoff 0.5s, 1.0s, 2.0s
+                raise e
+            time.sleep((2**i) / 2)  # exponential standoff: 0.5s, 1.0s, 2.0s