mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 05:24:22 +08:00
ultralytics 8.1.33
fix HUB model checks (#9153)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
fc6c66a4a4
commit
ec1d110689
@ -117,7 +117,7 @@ logging = [
|
|||||||
"dvclive>=2.12.0",
|
"dvclive>=2.12.0",
|
||||||
]
|
]
|
||||||
extra = [
|
extra = [
|
||||||
"hub-sdk>=0.0.2", # Ultralytics HUB
|
"hub-sdk>=0.0.5", # Ultralytics HUB
|
||||||
"ipython", # interactive notebook
|
"ipython", # interactive notebook
|
||||||
"albumentations>=1.0.3", # training augmentations
|
"albumentations>=1.0.3", # training augmentations
|
||||||
"pycocotools>=2.0.7", # COCO mAP
|
"pycocotools>=2.0.7", # COCO mAP
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
__version__ = "8.1.32"
|
__version__ = "8.1.33"
|
||||||
|
|
||||||
from ultralytics.data.explorer.explorer import Explorer
|
from ultralytics.data.explorer.explorer import Explorer
|
||||||
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
|
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
|
||||||
|
@ -119,30 +119,27 @@ class Model(nn.Module):
|
|||||||
self.metrics = None # validation/training metrics
|
self.metrics = None # validation/training metrics
|
||||||
self.session = None # HUB session
|
self.session = None # HUB session
|
||||||
self.task = task # task type
|
self.task = task # task type
|
||||||
self.model_name = model = str(model).strip() # strip spaces
|
model = str(model).strip()
|
||||||
|
|
||||||
# Check if Ultralytics HUB model from https://hub.ultralytics.com
|
# Check if Ultralytics HUB model from https://hub.ultralytics.com
|
||||||
if self.is_hub_model(model):
|
if self.is_hub_model(model):
|
||||||
# Fetch model from HUB
|
# Fetch model from HUB
|
||||||
checks.check_requirements("hub-sdk>=0.0.5")
|
checks.check_requirements("hub-sdk>=0.0.6")
|
||||||
self.session = self._get_hub_session(model)
|
self.session = self._get_hub_session(model)
|
||||||
model = self.session.model_file
|
model = self.session.model_file
|
||||||
|
|
||||||
# Check if Triton Server model
|
# Check if Triton Server model
|
||||||
elif self.is_triton_model(model):
|
elif self.is_triton_model(model):
|
||||||
self.model = model
|
self.model_name = self.model = model
|
||||||
self.task = task
|
self.task = task
|
||||||
return
|
return
|
||||||
|
|
||||||
# Load or create new YOLO model
|
# Load or create new YOLO model
|
||||||
model = checks.check_model_file_from_stem(model) # add suffix, i.e. yolov8n -> yolov8n.pt
|
|
||||||
if Path(model).suffix in (".yaml", ".yml"):
|
if Path(model).suffix in (".yaml", ".yml"):
|
||||||
self._new(model, task=task, verbose=verbose)
|
self._new(model, task=task, verbose=verbose)
|
||||||
else:
|
else:
|
||||||
self._load(model, task=task)
|
self._load(model, task=task)
|
||||||
|
|
||||||
self.model_name = model
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
|
source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
|
||||||
@ -190,8 +187,8 @@ class Model(nn.Module):
|
|||||||
return any(
|
return any(
|
||||||
(
|
(
|
||||||
model.startswith(f"{HUB_WEB_ROOT}/models/"), # i.e. https://hub.ultralytics.com/models/MODEL_ID
|
model.startswith(f"{HUB_WEB_ROOT}/models/"), # i.e. https://hub.ultralytics.com/models/MODEL_ID
|
||||||
[len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODELID
|
[len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODEL
|
||||||
len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"), # MODELID
|
len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"), # MODEL
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -215,6 +212,7 @@ class Model(nn.Module):
|
|||||||
# Below added to allow export from YAMLs
|
# Below added to allow export from YAMLs
|
||||||
self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args)
|
self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args)
|
||||||
self.model.task = self.task
|
self.model.task = self.task
|
||||||
|
self.model_name = cfg
|
||||||
|
|
||||||
def _load(self, weights: str, task=None) -> None:
|
def _load(self, weights: str, task=None) -> None:
|
||||||
"""
|
"""
|
||||||
@ -224,19 +222,23 @@ class Model(nn.Module):
|
|||||||
weights (str): model checkpoint to be loaded
|
weights (str): model checkpoint to be loaded
|
||||||
task (str | None): model task
|
task (str | None): model task
|
||||||
"""
|
"""
|
||||||
suffix = Path(weights).suffix
|
if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):
|
||||||
if suffix == ".pt":
|
weights = checks.check_file(weights) # automatically download and return local filename
|
||||||
|
weights = checks.check_model_file_from_stem(weights) # add suffix, i.e. yolov8n -> yolov8n.pt
|
||||||
|
|
||||||
|
if Path(weights).suffix == ".pt":
|
||||||
self.model, self.ckpt = attempt_load_one_weight(weights)
|
self.model, self.ckpt = attempt_load_one_weight(weights)
|
||||||
self.task = self.model.args["task"]
|
self.task = self.model.args["task"]
|
||||||
self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
|
self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
|
||||||
self.ckpt_path = self.model.pt_path
|
self.ckpt_path = self.model.pt_path
|
||||||
else:
|
else:
|
||||||
weights = checks.check_file(weights)
|
weights = checks.check_file(weights) # runs in all cases, not redundant with above call
|
||||||
self.model, self.ckpt = weights, None
|
self.model, self.ckpt = weights, None
|
||||||
self.task = task or guess_model_task(weights)
|
self.task = task or guess_model_task(weights)
|
||||||
self.ckpt_path = weights
|
self.ckpt_path = weights
|
||||||
self.overrides["model"] = weights
|
self.overrides["model"] = weights
|
||||||
self.overrides["task"] = self.task
|
self.overrides["task"] = self.task
|
||||||
|
self.model_name = weights
|
||||||
|
|
||||||
def _check_is_pytorch_model(self) -> None:
|
def _check_is_pytorch_model(self) -> None:
|
||||||
"""Raises TypeError is model is not a PyTorch model."""
|
"""Raises TypeError is model is not a PyTorch model."""
|
||||||
|
@ -23,7 +23,7 @@ def login(api_key: str = None, save=True) -> bool:
|
|||||||
Returns:
|
Returns:
|
||||||
(bool): True if authentication is successful, False otherwise.
|
(bool): True if authentication is successful, False otherwise.
|
||||||
"""
|
"""
|
||||||
checks.check_requirements("hub-sdk>=0.0.2")
|
checks.check_requirements("hub-sdk>=0.0.6")
|
||||||
from hub_sdk import HUBClient
|
from hub_sdk import HUBClient
|
||||||
|
|
||||||
api_key_url = f"{HUB_WEB_ROOT}/settings?tab=api+keys" # set the redirect URL
|
api_key_url = f"{HUB_WEB_ROOT}/settings?tab=api+keys" # set the redirect URL
|
||||||
|
@ -603,7 +603,7 @@ class AutoBackend(nn.Module):
|
|||||||
from ultralytics.engine.exporter import export_formats
|
from ultralytics.engine.exporter import export_formats
|
||||||
|
|
||||||
sf = list(export_formats().Suffix) # export suffixes
|
sf = list(export_formats().Suffix) # export suffixes
|
||||||
if not is_url(p, check=False) and not isinstance(p, str):
|
if not is_url(p) and not isinstance(p, str):
|
||||||
check_suffix(p, sf) # checks
|
check_suffix(p, sf) # checks
|
||||||
name = Path(p).name
|
name = Path(p).name
|
||||||
types = [s in name for s in sf]
|
types = [s in name for s in sf]
|
||||||
|
@ -315,7 +315,7 @@ def check_font(font="Arial.ttf"):
|
|||||||
|
|
||||||
# Download to USER_CONFIG_DIR if missing
|
# Download to USER_CONFIG_DIR if missing
|
||||||
url = f"https://ultralytics.com/assets/{name}"
|
url = f"https://ultralytics.com/assets/{name}"
|
||||||
if downloads.is_url(url):
|
if downloads.is_url(url, check=True):
|
||||||
downloads.safe_download(url=url, file=file)
|
downloads.safe_download(url=url, file=file)
|
||||||
return file
|
return file
|
||||||
|
|
||||||
@ -498,7 +498,7 @@ def check_file(file, suffix="", download=True, hard=True):
|
|||||||
raise FileNotFoundError(f"'{file}' does not exist")
|
raise FileNotFoundError(f"'{file}' does not exist")
|
||||||
elif len(files) > 1 and hard:
|
elif len(files) > 1 and hard:
|
||||||
raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
|
raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
|
||||||
return files[0] if len(files) else [] # return file
|
return files[0] if len(files) else [] if hard else file # return file
|
||||||
|
|
||||||
|
|
||||||
def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
|
def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
|
||||||
|
@ -33,7 +33,7 @@ GITHUB_ASSETS_NAMES = (
|
|||||||
GITHUB_ASSETS_STEMS = [Path(k).stem for k in GITHUB_ASSETS_NAMES]
|
GITHUB_ASSETS_STEMS = [Path(k).stem for k in GITHUB_ASSETS_NAMES]
|
||||||
|
|
||||||
|
|
||||||
def is_url(url, check=True):
|
def is_url(url, check=False):
|
||||||
"""
|
"""
|
||||||
Validates if the given string is a URL and optionally checks if the URL exists online.
|
Validates if the given string is a URL and optionally checks if the URL exists online.
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user