mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-11-04 00:45:38 +08:00 
			
		
		
		
	ultralytics 8.1.31 NCNN and CLIP updates (#9235)
				
					
				
			Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
		
							parent
							
								
									41c2d8d99f
								
							
						
					
					
						commit
						3c179f87cb
					
				@ -19,6 +19,10 @@ keywords: Ultralytics, YOLO, Configuration, cfg2dict, handle_deprecation, merge_
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
<br><br>
 | 
					<br><br>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## ::: ultralytics.cfg.check_cfg
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					<br><br>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## ::: ultralytics.cfg.get_save_dir
 | 
					## ::: ultralytics.cfg.get_save_dir
 | 
				
			||||||
 | 
					
 | 
				
			||||||
<br><br>
 | 
					<br><br>
 | 
				
			||||||
 | 
				
			|||||||
@ -1,6 +1,6 @@
 | 
				
			|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
 | 
					# Ultralytics YOLO 🚀, AGPL-3.0 license
 | 
				
			||||||
 | 
					
 | 
				
			||||||
__version__ = "8.1.30"
 | 
					__version__ = "8.1.31"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ultralytics.data.explorer.explorer import Explorer
 | 
					from ultralytics.data.explorer.explorer import Explorer
 | 
				
			||||||
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
 | 
					from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
 | 
				
			||||||
 | 
				
			|||||||
@ -30,8 +30,8 @@ from ultralytics.utils import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Define valid tasks and modes
 | 
					# Define valid tasks and modes
 | 
				
			||||||
MODES = "train", "val", "predict", "export", "track", "benchmark"
 | 
					MODES = {"train", "val", "predict", "export", "track", "benchmark"}
 | 
				
			||||||
TASKS = "detect", "segment", "classify", "pose", "obb"
 | 
					TASKS = {"detect", "segment", "classify", "pose", "obb"}
 | 
				
			||||||
TASK2DATA = {
 | 
					TASK2DATA = {
 | 
				
			||||||
    "detect": "coco8.yaml",
 | 
					    "detect": "coco8.yaml",
 | 
				
			||||||
    "segment": "coco8-seg.yaml",
 | 
					    "segment": "coco8-seg.yaml",
 | 
				
			||||||
@ -93,8 +93,8 @@ CLI_HELP_MSG = f"""
 | 
				
			|||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Define keys for arg type checks
 | 
					# Define keys for arg type checks
 | 
				
			||||||
CFG_FLOAT_KEYS = "warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time"
 | 
					CFG_FLOAT_KEYS = {"warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time"}
 | 
				
			||||||
CFG_FRACTION_KEYS = (
 | 
					CFG_FRACTION_KEYS = {
 | 
				
			||||||
    "dropout",
 | 
					    "dropout",
 | 
				
			||||||
    "iou",
 | 
					    "iou",
 | 
				
			||||||
    "lr0",
 | 
					    "lr0",
 | 
				
			||||||
@ -118,8 +118,8 @@ CFG_FRACTION_KEYS = (
 | 
				
			|||||||
    "conf",
 | 
					    "conf",
 | 
				
			||||||
    "iou",
 | 
					    "iou",
 | 
				
			||||||
    "fraction",
 | 
					    "fraction",
 | 
				
			||||||
)  # fraction floats 0.0 - 1.0
 | 
					}  # fraction floats 0.0 - 1.0
 | 
				
			||||||
CFG_INT_KEYS = (
 | 
					CFG_INT_KEYS = {
 | 
				
			||||||
    "epochs",
 | 
					    "epochs",
 | 
				
			||||||
    "patience",
 | 
					    "patience",
 | 
				
			||||||
    "batch",
 | 
					    "batch",
 | 
				
			||||||
@ -133,8 +133,8 @@ CFG_INT_KEYS = (
 | 
				
			|||||||
    "workspace",
 | 
					    "workspace",
 | 
				
			||||||
    "nbs",
 | 
					    "nbs",
 | 
				
			||||||
    "save_period",
 | 
					    "save_period",
 | 
				
			||||||
)
 | 
					}
 | 
				
			||||||
CFG_BOOL_KEYS = (
 | 
					CFG_BOOL_KEYS = {
 | 
				
			||||||
    "save",
 | 
					    "save",
 | 
				
			||||||
    "exist_ok",
 | 
					    "exist_ok",
 | 
				
			||||||
    "verbose",
 | 
					    "verbose",
 | 
				
			||||||
@ -169,7 +169,7 @@ CFG_BOOL_KEYS = (
 | 
				
			|||||||
    "nms",
 | 
					    "nms",
 | 
				
			||||||
    "profile",
 | 
					    "profile",
 | 
				
			||||||
    "multi_scale",
 | 
					    "multi_scale",
 | 
				
			||||||
)
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def cfg2dict(cfg):
 | 
					def cfg2dict(cfg):
 | 
				
			||||||
@ -219,33 +219,46 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
 | 
				
			|||||||
        LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.")
 | 
					        LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Type and Value checks
 | 
					    # Type and Value checks
 | 
				
			||||||
 | 
					    check_cfg(cfg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Return instance
 | 
				
			||||||
 | 
					    return IterableSimpleNamespace(**cfg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def check_cfg(cfg, hard=True):
 | 
				
			||||||
 | 
					    """Check Ultralytics configuration argument types and values."""
 | 
				
			||||||
    for k, v in cfg.items():
 | 
					    for k, v in cfg.items():
 | 
				
			||||||
        if v is not None:  # None values may be from optional args
 | 
					        if v is not None:  # None values may be from optional args
 | 
				
			||||||
            if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
 | 
					            if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
 | 
				
			||||||
                raise TypeError(
 | 
					                if hard:
 | 
				
			||||||
                    f"'{k}={v}' is of invalid type {type(v).__name__}. "
 | 
					 | 
				
			||||||
                    f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
            elif k in CFG_FRACTION_KEYS:
 | 
					 | 
				
			||||||
                if not isinstance(v, (int, float)):
 | 
					 | 
				
			||||||
                    raise TypeError(
 | 
					                    raise TypeError(
 | 
				
			||||||
                        f"'{k}={v}' is of invalid type {type(v).__name__}. "
 | 
					                        f"'{k}={v}' is of invalid type {type(v).__name__}. "
 | 
				
			||||||
                        f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
 | 
					                        f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
 | 
					                cfg[k] = float(v)
 | 
				
			||||||
 | 
					            elif k in CFG_FRACTION_KEYS:
 | 
				
			||||||
 | 
					                if not isinstance(v, (int, float)):
 | 
				
			||||||
 | 
					                    if hard:
 | 
				
			||||||
 | 
					                        raise TypeError(
 | 
				
			||||||
 | 
					                            f"'{k}={v}' is of invalid type {type(v).__name__}. "
 | 
				
			||||||
 | 
					                            f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
 | 
				
			||||||
 | 
					                        )
 | 
				
			||||||
 | 
					                    cfg[k] = float(v)
 | 
				
			||||||
                if not (0.0 <= v <= 1.0):
 | 
					                if not (0.0 <= v <= 1.0):
 | 
				
			||||||
                    raise ValueError(f"'{k}={v}' is an invalid value. " f"Valid '{k}' values are between 0.0 and 1.0.")
 | 
					                    raise ValueError(f"'{k}={v}' is an invalid value. " f"Valid '{k}' values are between 0.0 and 1.0.")
 | 
				
			||||||
            elif k in CFG_INT_KEYS and not isinstance(v, int):
 | 
					            elif k in CFG_INT_KEYS and not isinstance(v, int):
 | 
				
			||||||
                raise TypeError(
 | 
					                if hard:
 | 
				
			||||||
                    f"'{k}={v}' is of invalid type {type(v).__name__}. " f"'{k}' must be an int (i.e. '{k}=8')"
 | 
					                    raise TypeError(
 | 
				
			||||||
                )
 | 
					                        f"'{k}={v}' is of invalid type {type(v).__name__}. " f"'{k}' must be an int (i.e. '{k}=8')"
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                cfg[k] = int(v)
 | 
				
			||||||
            elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
 | 
					            elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
 | 
				
			||||||
                raise TypeError(
 | 
					                if hard:
 | 
				
			||||||
                    f"'{k}={v}' is of invalid type {type(v).__name__}. "
 | 
					                    raise TypeError(
 | 
				
			||||||
                    f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')"
 | 
					                        f"'{k}={v}' is of invalid type {type(v).__name__}. "
 | 
				
			||||||
                )
 | 
					                        f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')"
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
    # Return instance
 | 
					                cfg[k] = bool(v)
 | 
				
			||||||
    return IterableSimpleNamespace(**cfg)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_save_dir(args, name=None):
 | 
					def get_save_dir(args, name=None):
 | 
				
			||||||
@ -464,10 +477,10 @@ def entrypoint(debug=""):
 | 
				
			|||||||
    overrides = {}  # basic overrides, i.e. imgsz=320
 | 
					    overrides = {}  # basic overrides, i.e. imgsz=320
 | 
				
			||||||
    for a in merge_equals_args(args):  # merge spaces around '=' sign
 | 
					    for a in merge_equals_args(args):  # merge spaces around '=' sign
 | 
				
			||||||
        if a.startswith("--"):
 | 
					        if a.startswith("--"):
 | 
				
			||||||
            LOGGER.warning(f"WARNING ⚠️ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
 | 
					            LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
 | 
				
			||||||
            a = a[2:]
 | 
					            a = a[2:]
 | 
				
			||||||
        if a.endswith(","):
 | 
					        if a.endswith(","):
 | 
				
			||||||
            LOGGER.warning(f"WARNING ⚠️ '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.")
 | 
					            LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.")
 | 
				
			||||||
            a = a[:-1]
 | 
					            a = a[:-1]
 | 
				
			||||||
        if "=" in a:
 | 
					        if "=" in a:
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
@ -504,7 +517,7 @@ def entrypoint(debug=""):
 | 
				
			|||||||
    mode = overrides.get("mode")
 | 
					    mode = overrides.get("mode")
 | 
				
			||||||
    if mode is None:
 | 
					    if mode is None:
 | 
				
			||||||
        mode = DEFAULT_CFG.mode or "predict"
 | 
					        mode = DEFAULT_CFG.mode or "predict"
 | 
				
			||||||
        LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")
 | 
					        LOGGER.warning(f"WARNING ⚠️ 'mode' argument is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")
 | 
				
			||||||
    elif mode not in MODES:
 | 
					    elif mode not in MODES:
 | 
				
			||||||
        raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}")
 | 
					        raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -520,7 +533,7 @@ def entrypoint(debug=""):
 | 
				
			|||||||
    model = overrides.pop("model", DEFAULT_CFG.model)
 | 
					    model = overrides.pop("model", DEFAULT_CFG.model)
 | 
				
			||||||
    if model is None:
 | 
					    if model is None:
 | 
				
			||||||
        model = "yolov8n.pt"
 | 
					        model = "yolov8n.pt"
 | 
				
			||||||
        LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
 | 
					        LOGGER.warning(f"WARNING ⚠️ 'model' argument is missing. Using default 'model={model}'.")
 | 
				
			||||||
    overrides["model"] = model
 | 
					    overrides["model"] = model
 | 
				
			||||||
    stem = Path(model).stem.lower()
 | 
					    stem = Path(model).stem.lower()
 | 
				
			||||||
    if "rtdetr" in stem:  # guess architecture
 | 
					    if "rtdetr" in stem:  # guess architecture
 | 
				
			||||||
@ -554,15 +567,15 @@ def entrypoint(debug=""):
 | 
				
			|||||||
    # Mode
 | 
					    # Mode
 | 
				
			||||||
    if mode in ("predict", "track") and "source" not in overrides:
 | 
					    if mode in ("predict", "track") and "source" not in overrides:
 | 
				
			||||||
        overrides["source"] = DEFAULT_CFG.source or ASSETS
 | 
					        overrides["source"] = DEFAULT_CFG.source or ASSETS
 | 
				
			||||||
        LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
 | 
					        LOGGER.warning(f"WARNING ⚠️ 'source' argument is missing. Using default 'source={overrides['source']}'.")
 | 
				
			||||||
    elif mode in ("train", "val"):
 | 
					    elif mode in ("train", "val"):
 | 
				
			||||||
        if "data" not in overrides and "resume" not in overrides:
 | 
					        if "data" not in overrides and "resume" not in overrides:
 | 
				
			||||||
            overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
 | 
					            overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
 | 
				
			||||||
            LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
 | 
					            LOGGER.warning(f"WARNING ⚠️ 'data' argument is missing. Using default 'data={overrides['data']}'.")
 | 
				
			||||||
    elif mode == "export":
 | 
					    elif mode == "export":
 | 
				
			||||||
        if "format" not in overrides:
 | 
					        if "format" not in overrides:
 | 
				
			||||||
            overrides["format"] = DEFAULT_CFG.format or "torchscript"
 | 
					            overrides["format"] = DEFAULT_CFG.format or "torchscript"
 | 
				
			||||||
            LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.")
 | 
					            LOGGER.warning(f"WARNING ⚠️ 'format' argument is missing. Using default 'format={overrides['format']}'.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Run command in python
 | 
					    # Run command in python
 | 
				
			||||||
    getattr(model, mode)(**overrides)  # default args from model
 | 
					    getattr(model, mode)(**overrides)  # default args from model
 | 
				
			||||||
 | 
				
			|||||||
@ -129,7 +129,7 @@ def check_source(source):
 | 
				
			|||||||
    webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False
 | 
					    webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False
 | 
				
			||||||
    if isinstance(source, (str, int, Path)):  # int for local usb camera
 | 
					    if isinstance(source, (str, int, Path)):  # int for local usb camera
 | 
				
			||||||
        source = str(source)
 | 
					        source = str(source)
 | 
				
			||||||
        is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
 | 
					        is_file = Path(source).suffix[1:] in (IMG_FORMATS | VID_FORMATS)
 | 
				
			||||||
        is_url = source.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://"))
 | 
					        is_url = source.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://"))
 | 
				
			||||||
        webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file)
 | 
					        webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file)
 | 
				
			||||||
        screenshot = source.lower() == "screen"
 | 
					        screenshot = source.lower() == "screen"
 | 
				
			||||||
 | 
				
			|||||||
@ -35,8 +35,8 @@ from ultralytics.utils.downloads import download, safe_download, unzip_file
 | 
				
			|||||||
from ultralytics.utils.ops import segments2boxes
 | 
					from ultralytics.utils.ops import segments2boxes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
HELP_URL = "See https://docs.ultralytics.com/datasets/detect for dataset formatting guidance."
 | 
					HELP_URL = "See https://docs.ultralytics.com/datasets/detect for dataset formatting guidance."
 | 
				
			||||||
IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"  # image suffixes
 | 
					IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"}  # image suffixes
 | 
				
			||||||
VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"  # video suffixes
 | 
					VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"}  # video suffixes
 | 
				
			||||||
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true"  # global pin_memory for dataloaders
 | 
					PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true"  # global pin_memory for dataloaders
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -385,14 +385,12 @@ class Results(SimpleClass):
 | 
				
			|||||||
                BGR=True,
 | 
					                BGR=True,
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def tojson(self, normalize=False):
 | 
					    def summary(self, normalize=False):
 | 
				
			||||||
        """Convert the object to JSON format."""
 | 
					        """Convert the results to a summarized format."""
 | 
				
			||||||
        if self.probs is not None:
 | 
					        if self.probs is not None:
 | 
				
			||||||
            LOGGER.warning("Warning: Classify task do not support `tojson` yet.")
 | 
					            LOGGER.warning("Warning: Classify task do not support `summary` and `tojson` yet.")
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        import json
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Create list of detection dictionaries
 | 
					        # Create list of detection dictionaries
 | 
				
			||||||
        results = []
 | 
					        results = []
 | 
				
			||||||
        data = self.boxes.data.cpu().tolist()
 | 
					        data = self.boxes.data.cpu().tolist()
 | 
				
			||||||
@ -413,8 +411,13 @@ class Results(SimpleClass):
 | 
				
			|||||||
                result["keypoints"] = {"x": (x / w).tolist(), "y": (y / h).tolist(), "visible": visible.tolist()}
 | 
					                result["keypoints"] = {"x": (x / w).tolist(), "y": (y / h).tolist(), "visible": visible.tolist()}
 | 
				
			||||||
            results.append(result)
 | 
					            results.append(result)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Convert detections to JSON
 | 
					        return results
 | 
				
			||||||
        return json.dumps(results, indent=2)
 | 
					
 | 
				
			||||||
 | 
					    def tojson(self, normalize=False):
 | 
				
			||||||
 | 
					        """Convert the results to JSON format."""
 | 
				
			||||||
 | 
					        import json
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return json.dumps(self.summary(normalize=normalize), indent=2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Boxes(BaseTensor):
 | 
					class Boxes(BaseTensor):
 | 
				
			||||||
 | 
				
			|||||||
@ -509,14 +509,9 @@ class AutoBackend(nn.Module):
 | 
				
			|||||||
        # NCNN
 | 
					        # NCNN
 | 
				
			||||||
        elif self.ncnn:
 | 
					        elif self.ncnn:
 | 
				
			||||||
            mat_in = self.pyncnn.Mat(im[0].cpu().numpy())
 | 
					            mat_in = self.pyncnn.Mat(im[0].cpu().numpy())
 | 
				
			||||||
            ex = self.net.create_extractor()
 | 
					            with self.net.create_extractor() as ex:
 | 
				
			||||||
            input_names, output_names = self.net.input_names(), self.net.output_names()
 | 
					                ex.input(self.net.input_names()[0], mat_in)
 | 
				
			||||||
            ex.input(input_names[0], mat_in)
 | 
					                y = [np.array(ex.extract(x)[1])[None] for x in self.net.output_names()]
 | 
				
			||||||
            y = []
 | 
					 | 
				
			||||||
            for output_name in output_names:
 | 
					 | 
				
			||||||
                mat_out = self.pyncnn.Mat()
 | 
					 | 
				
			||||||
                ex.extract(output_name, mat_out)
 | 
					 | 
				
			||||||
                y.append(np.array(mat_out)[None])
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # NVIDIA Triton Inference Server
 | 
					        # NVIDIA Triton Inference Server
 | 
				
			||||||
        elif self.triton:
 | 
					        elif self.triton:
 | 
				
			||||||
 | 
				
			|||||||
@ -560,7 +560,8 @@ class WorldModel(DetectionModel):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
 | 
					    def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
 | 
				
			||||||
        """Initialize YOLOv8 world model with given config and parameters."""
 | 
					        """Initialize YOLOv8 world model with given config and parameters."""
 | 
				
			||||||
        self.txt_feats = torch.randn(1, nc or 80, 512)  # placeholder
 | 
					        self.txt_feats = torch.randn(1, nc or 80, 512)  # features placeholder
 | 
				
			||||||
 | 
					        self.clip_model = None  # CLIP model placeholder
 | 
				
			||||||
        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 | 
					        super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def set_classes(self, text):
 | 
					    def set_classes(self, text):
 | 
				
			||||||
@ -571,10 +572,11 @@ class WorldModel(DetectionModel):
 | 
				
			|||||||
            check_requirements("git+https://github.com/openai/CLIP.git")
 | 
					            check_requirements("git+https://github.com/openai/CLIP.git")
 | 
				
			||||||
            import clip
 | 
					            import clip
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        model, _ = clip.load("ViT-B/32")
 | 
					        if not self.clip_model:
 | 
				
			||||||
        device = next(model.parameters()).device
 | 
					            self.clip_model = clip.load("ViT-B/32")[0]
 | 
				
			||||||
 | 
					        device = next(self.clip_model.parameters()).device
 | 
				
			||||||
        text_token = clip.tokenize(text).to(device)
 | 
					        text_token = clip.tokenize(text).to(device)
 | 
				
			||||||
        txt_feats = model.encode_text(text_token).to(dtype=torch.float32)
 | 
					        txt_feats = self.clip_model.encode_text(text_token).to(dtype=torch.float32)
 | 
				
			||||||
        txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
 | 
					        txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
 | 
				
			||||||
        self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]).detach()
 | 
					        self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]).detach()
 | 
				
			||||||
        self.model[-1].nc = len(text)
 | 
					        self.model[-1].nc = len(text)
 | 
				
			||||||
@ -841,7 +843,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
 | 
				
			|||||||
                    args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
 | 
					                    args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
 | 
					        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
 | 
				
			||||||
        if m in (
 | 
					        if m in {
 | 
				
			||||||
            Classify,
 | 
					            Classify,
 | 
				
			||||||
            Conv,
 | 
					            Conv,
 | 
				
			||||||
            ConvTranspose,
 | 
					            ConvTranspose,
 | 
				
			||||||
@ -867,7 +869,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
 | 
				
			|||||||
            DWConvTranspose2d,
 | 
					            DWConvTranspose2d,
 | 
				
			||||||
            C3x,
 | 
					            C3x,
 | 
				
			||||||
            RepC3,
 | 
					            RepC3,
 | 
				
			||||||
        ):
 | 
					        }:
 | 
				
			||||||
            c1, c2 = ch[f], args[0]
 | 
					            c1, c2 = ch[f], args[0]
 | 
				
			||||||
            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
 | 
					            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
 | 
				
			||||||
                c2 = make_divisible(min(c2, max_channels) * width, 8)
 | 
					                c2 = make_divisible(min(c2, max_channels) * width, 8)
 | 
				
			||||||
@ -883,7 +885,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
 | 
				
			|||||||
                n = 1
 | 
					                n = 1
 | 
				
			||||||
        elif m is AIFI:
 | 
					        elif m is AIFI:
 | 
				
			||||||
            args = [ch[f], *args]
 | 
					            args = [ch[f], *args]
 | 
				
			||||||
        elif m in (HGStem, HGBlock):
 | 
					        elif m in {HGStem, HGBlock}:
 | 
				
			||||||
            c1, cm, c2 = ch[f], args[0], args[1]
 | 
					            c1, cm, c2 = ch[f], args[0], args[1]
 | 
				
			||||||
            args = [c1, cm, c2, *args[2:]]
 | 
					            args = [c1, cm, c2, *args[2:]]
 | 
				
			||||||
            if m is HGBlock:
 | 
					            if m is HGBlock:
 | 
				
			||||||
@ -895,7 +897,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
 | 
				
			|||||||
            args = [ch[f]]
 | 
					            args = [ch[f]]
 | 
				
			||||||
        elif m is Concat:
 | 
					        elif m is Concat:
 | 
				
			||||||
            c2 = sum(ch[x] for x in f)
 | 
					            c2 = sum(ch[x] for x in f)
 | 
				
			||||||
        elif m in (Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn):
 | 
					        elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn}:
 | 
				
			||||||
            args.append([ch[x] for x in f])
 | 
					            args.append([ch[x] for x in f])
 | 
				
			||||||
            if m is Segment:
 | 
					            if m is Segment:
 | 
				
			||||||
                args[2] = make_divisible(min(args[2], max_channels) * width, 8)
 | 
					                args[2] = make_divisible(min(args[2], max_channels) * width, 8)
 | 
				
			||||||
@ -978,7 +980,7 @@ def guess_model_task(model):
 | 
				
			|||||||
    def cfg2task(cfg):
 | 
					    def cfg2task(cfg):
 | 
				
			||||||
        """Guess from YAML dictionary."""
 | 
					        """Guess from YAML dictionary."""
 | 
				
			||||||
        m = cfg["head"][-1][-2].lower()  # output module name
 | 
					        m = cfg["head"][-1][-2].lower()  # output module name
 | 
				
			||||||
        if m in ("classify", "classifier", "cls", "fc"):
 | 
					        if m in {"classify", "classifier", "cls", "fc"}:
 | 
				
			||||||
            return "classify"
 | 
					            return "classify"
 | 
				
			||||||
        if m == "detect":
 | 
					        if m == "detect":
 | 
				
			||||||
            return "detect"
 | 
					            return "detect"
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user