mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 05:24:22 +08:00
ultralytics 8.1.31
NCNN and CLIP updates (#9235)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
41c2d8d99f
commit
3c179f87cb
@ -19,6 +19,10 @@ keywords: Ultralytics, YOLO, Configuration, cfg2dict, handle_deprecation, merge_
|
||||
|
||||
<br><br>
|
||||
|
||||
## ::: ultralytics.cfg.check_cfg
|
||||
|
||||
<br><br>
|
||||
|
||||
## ::: ultralytics.cfg.get_save_dir
|
||||
|
||||
<br><br>
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = "8.1.30"
|
||||
__version__ = "8.1.31"
|
||||
|
||||
from ultralytics.data.explorer.explorer import Explorer
|
||||
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
|
||||
|
@ -30,8 +30,8 @@ from ultralytics.utils import (
|
||||
)
|
||||
|
||||
# Define valid tasks and modes
|
||||
MODES = "train", "val", "predict", "export", "track", "benchmark"
|
||||
TASKS = "detect", "segment", "classify", "pose", "obb"
|
||||
MODES = {"train", "val", "predict", "export", "track", "benchmark"}
|
||||
TASKS = {"detect", "segment", "classify", "pose", "obb"}
|
||||
TASK2DATA = {
|
||||
"detect": "coco8.yaml",
|
||||
"segment": "coco8-seg.yaml",
|
||||
@ -93,8 +93,8 @@ CLI_HELP_MSG = f"""
|
||||
"""
|
||||
|
||||
# Define keys for arg type checks
|
||||
CFG_FLOAT_KEYS = "warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time"
|
||||
CFG_FRACTION_KEYS = (
|
||||
CFG_FLOAT_KEYS = {"warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time"}
|
||||
CFG_FRACTION_KEYS = {
|
||||
"dropout",
|
||||
"iou",
|
||||
"lr0",
|
||||
@ -118,8 +118,8 @@ CFG_FRACTION_KEYS = (
|
||||
"conf",
|
||||
"iou",
|
||||
"fraction",
|
||||
) # fraction floats 0.0 - 1.0
|
||||
CFG_INT_KEYS = (
|
||||
} # fraction floats 0.0 - 1.0
|
||||
CFG_INT_KEYS = {
|
||||
"epochs",
|
||||
"patience",
|
||||
"batch",
|
||||
@ -133,8 +133,8 @@ CFG_INT_KEYS = (
|
||||
"workspace",
|
||||
"nbs",
|
||||
"save_period",
|
||||
)
|
||||
CFG_BOOL_KEYS = (
|
||||
}
|
||||
CFG_BOOL_KEYS = {
|
||||
"save",
|
||||
"exist_ok",
|
||||
"verbose",
|
||||
@ -169,7 +169,7 @@ CFG_BOOL_KEYS = (
|
||||
"nms",
|
||||
"profile",
|
||||
"multi_scale",
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def cfg2dict(cfg):
|
||||
@ -219,33 +219,46 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
|
||||
LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.")
|
||||
|
||||
# Type and Value checks
|
||||
check_cfg(cfg)
|
||||
|
||||
# Return instance
|
||||
return IterableSimpleNamespace(**cfg)
|
||||
|
||||
|
||||
def check_cfg(cfg, hard=True):
|
||||
"""Check Ultralytics configuration argument types and values."""
|
||||
for k, v in cfg.items():
|
||||
if v is not None: # None values may be from optional args
|
||||
if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
|
||||
raise TypeError(
|
||||
f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
|
||||
)
|
||||
elif k in CFG_FRACTION_KEYS:
|
||||
if not isinstance(v, (int, float)):
|
||||
if hard:
|
||||
raise TypeError(
|
||||
f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
|
||||
)
|
||||
cfg[k] = float(v)
|
||||
elif k in CFG_FRACTION_KEYS:
|
||||
if not isinstance(v, (int, float)):
|
||||
if hard:
|
||||
raise TypeError(
|
||||
f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
|
||||
)
|
||||
cfg[k] = float(v)
|
||||
if not (0.0 <= v <= 1.0):
|
||||
raise ValueError(f"'{k}={v}' is an invalid value. " f"Valid '{k}' values are between 0.0 and 1.0.")
|
||||
elif k in CFG_INT_KEYS and not isinstance(v, int):
|
||||
raise TypeError(
|
||||
f"'{k}={v}' is of invalid type {type(v).__name__}. " f"'{k}' must be an int (i.e. '{k}=8')"
|
||||
)
|
||||
if hard:
|
||||
raise TypeError(
|
||||
f"'{k}={v}' is of invalid type {type(v).__name__}. " f"'{k}' must be an int (i.e. '{k}=8')"
|
||||
)
|
||||
cfg[k] = int(v)
|
||||
elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
|
||||
raise TypeError(
|
||||
f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')"
|
||||
)
|
||||
|
||||
# Return instance
|
||||
return IterableSimpleNamespace(**cfg)
|
||||
if hard:
|
||||
raise TypeError(
|
||||
f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')"
|
||||
)
|
||||
cfg[k] = bool(v)
|
||||
|
||||
|
||||
def get_save_dir(args, name=None):
|
||||
@ -464,10 +477,10 @@ def entrypoint(debug=""):
|
||||
overrides = {} # basic overrides, i.e. imgsz=320
|
||||
for a in merge_equals_args(args): # merge spaces around '=' sign
|
||||
if a.startswith("--"):
|
||||
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
|
||||
LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
|
||||
a = a[2:]
|
||||
if a.endswith(","):
|
||||
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.")
|
||||
LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.")
|
||||
a = a[:-1]
|
||||
if "=" in a:
|
||||
try:
|
||||
@ -504,7 +517,7 @@ def entrypoint(debug=""):
|
||||
mode = overrides.get("mode")
|
||||
if mode is None:
|
||||
mode = DEFAULT_CFG.mode or "predict"
|
||||
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")
|
||||
LOGGER.warning(f"WARNING ⚠️ 'mode' argument is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")
|
||||
elif mode not in MODES:
|
||||
raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}")
|
||||
|
||||
@ -520,7 +533,7 @@ def entrypoint(debug=""):
|
||||
model = overrides.pop("model", DEFAULT_CFG.model)
|
||||
if model is None:
|
||||
model = "yolov8n.pt"
|
||||
LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
|
||||
LOGGER.warning(f"WARNING ⚠️ 'model' argument is missing. Using default 'model={model}'.")
|
||||
overrides["model"] = model
|
||||
stem = Path(model).stem.lower()
|
||||
if "rtdetr" in stem: # guess architecture
|
||||
@ -554,15 +567,15 @@ def entrypoint(debug=""):
|
||||
# Mode
|
||||
if mode in ("predict", "track") and "source" not in overrides:
|
||||
overrides["source"] = DEFAULT_CFG.source or ASSETS
|
||||
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
|
||||
LOGGER.warning(f"WARNING ⚠️ 'source' argument is missing. Using default 'source={overrides['source']}'.")
|
||||
elif mode in ("train", "val"):
|
||||
if "data" not in overrides and "resume" not in overrides:
|
||||
overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
|
||||
LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
|
||||
LOGGER.warning(f"WARNING ⚠️ 'data' argument is missing. Using default 'data={overrides['data']}'.")
|
||||
elif mode == "export":
|
||||
if "format" not in overrides:
|
||||
overrides["format"] = DEFAULT_CFG.format or "torchscript"
|
||||
LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.")
|
||||
LOGGER.warning(f"WARNING ⚠️ 'format' argument is missing. Using default 'format={overrides['format']}'.")
|
||||
|
||||
# Run command in python
|
||||
getattr(model, mode)(**overrides) # default args from model
|
||||
|
@ -129,7 +129,7 @@ def check_source(source):
|
||||
webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False
|
||||
if isinstance(source, (str, int, Path)): # int for local usb camera
|
||||
source = str(source)
|
||||
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
|
||||
is_file = Path(source).suffix[1:] in (IMG_FORMATS | VID_FORMATS)
|
||||
is_url = source.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://"))
|
||||
webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file)
|
||||
screenshot = source.lower() == "screen"
|
||||
|
@ -35,8 +35,8 @@ from ultralytics.utils.downloads import download, safe_download, unzip_file
|
||||
from ultralytics.utils.ops import segments2boxes
|
||||
|
||||
HELP_URL = "See https://docs.ultralytics.com/datasets/detect for dataset formatting guidance."
|
||||
IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm" # image suffixes
|
||||
VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm" # video suffixes
|
||||
IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"} # image suffixes
|
||||
VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"} # video suffixes
|
||||
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders
|
||||
|
||||
|
||||
|
@ -385,14 +385,12 @@ class Results(SimpleClass):
|
||||
BGR=True,
|
||||
)
|
||||
|
||||
def tojson(self, normalize=False):
|
||||
"""Convert the object to JSON format."""
|
||||
def summary(self, normalize=False):
|
||||
"""Convert the results to a summarized format."""
|
||||
if self.probs is not None:
|
||||
LOGGER.warning("Warning: Classify task do not support `tojson` yet.")
|
||||
LOGGER.warning("Warning: Classify task do not support `summary` and `tojson` yet.")
|
||||
return
|
||||
|
||||
import json
|
||||
|
||||
# Create list of detection dictionaries
|
||||
results = []
|
||||
data = self.boxes.data.cpu().tolist()
|
||||
@ -413,8 +411,13 @@ class Results(SimpleClass):
|
||||
result["keypoints"] = {"x": (x / w).tolist(), "y": (y / h).tolist(), "visible": visible.tolist()}
|
||||
results.append(result)
|
||||
|
||||
# Convert detections to JSON
|
||||
return json.dumps(results, indent=2)
|
||||
return results
|
||||
|
||||
def tojson(self, normalize=False):
|
||||
"""Convert the results to JSON format."""
|
||||
import json
|
||||
|
||||
return json.dumps(self.summary(normalize=normalize), indent=2)
|
||||
|
||||
|
||||
class Boxes(BaseTensor):
|
||||
|
@ -509,14 +509,9 @@ class AutoBackend(nn.Module):
|
||||
# NCNN
|
||||
elif self.ncnn:
|
||||
mat_in = self.pyncnn.Mat(im[0].cpu().numpy())
|
||||
ex = self.net.create_extractor()
|
||||
input_names, output_names = self.net.input_names(), self.net.output_names()
|
||||
ex.input(input_names[0], mat_in)
|
||||
y = []
|
||||
for output_name in output_names:
|
||||
mat_out = self.pyncnn.Mat()
|
||||
ex.extract(output_name, mat_out)
|
||||
y.append(np.array(mat_out)[None])
|
||||
with self.net.create_extractor() as ex:
|
||||
ex.input(self.net.input_names()[0], mat_in)
|
||||
y = [np.array(ex.extract(x)[1])[None] for x in self.net.output_names()]
|
||||
|
||||
# NVIDIA Triton Inference Server
|
||||
elif self.triton:
|
||||
|
@ -560,7 +560,8 @@ class WorldModel(DetectionModel):
|
||||
|
||||
def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
|
||||
"""Initialize YOLOv8 world model with given config and parameters."""
|
||||
self.txt_feats = torch.randn(1, nc or 80, 512) # placeholder
|
||||
self.txt_feats = torch.randn(1, nc or 80, 512) # features placeholder
|
||||
self.clip_model = None # CLIP model placeholder
|
||||
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
||||
|
||||
def set_classes(self, text):
|
||||
@ -571,10 +572,11 @@ class WorldModel(DetectionModel):
|
||||
check_requirements("git+https://github.com/openai/CLIP.git")
|
||||
import clip
|
||||
|
||||
model, _ = clip.load("ViT-B/32")
|
||||
device = next(model.parameters()).device
|
||||
if not self.clip_model:
|
||||
self.clip_model = clip.load("ViT-B/32")[0]
|
||||
device = next(self.clip_model.parameters()).device
|
||||
text_token = clip.tokenize(text).to(device)
|
||||
txt_feats = model.encode_text(text_token).to(dtype=torch.float32)
|
||||
txt_feats = self.clip_model.encode_text(text_token).to(dtype=torch.float32)
|
||||
txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
|
||||
self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]).detach()
|
||||
self.model[-1].nc = len(text)
|
||||
@ -841,7 +843,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
|
||||
|
||||
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
|
||||
if m in (
|
||||
if m in {
|
||||
Classify,
|
||||
Conv,
|
||||
ConvTranspose,
|
||||
@ -867,7 +869,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||
DWConvTranspose2d,
|
||||
C3x,
|
||||
RepC3,
|
||||
):
|
||||
}:
|
||||
c1, c2 = ch[f], args[0]
|
||||
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
|
||||
c2 = make_divisible(min(c2, max_channels) * width, 8)
|
||||
@ -883,7 +885,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||
n = 1
|
||||
elif m is AIFI:
|
||||
args = [ch[f], *args]
|
||||
elif m in (HGStem, HGBlock):
|
||||
elif m in {HGStem, HGBlock}:
|
||||
c1, cm, c2 = ch[f], args[0], args[1]
|
||||
args = [c1, cm, c2, *args[2:]]
|
||||
if m is HGBlock:
|
||||
@ -895,7 +897,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||
args = [ch[f]]
|
||||
elif m is Concat:
|
||||
c2 = sum(ch[x] for x in f)
|
||||
elif m in (Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn):
|
||||
elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn}:
|
||||
args.append([ch[x] for x in f])
|
||||
if m is Segment:
|
||||
args[2] = make_divisible(min(args[2], max_channels) * width, 8)
|
||||
@ -978,7 +980,7 @@ def guess_model_task(model):
|
||||
def cfg2task(cfg):
|
||||
"""Guess from YAML dictionary."""
|
||||
m = cfg["head"][-1][-2].lower() # output module name
|
||||
if m in ("classify", "classifier", "cls", "fc"):
|
||||
if m in {"classify", "classifier", "cls", "fc"}:
|
||||
return "classify"
|
||||
if m == "detect":
|
||||
return "detect"
|
||||
|
Loading…
x
Reference in New Issue
Block a user