mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-22 21:04:21 +08:00
ultralytics 8.1.33
fix HUB model checks (#9153)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
fc6c66a4a4
commit
ec1d110689
@ -117,7 +117,7 @@ logging = [
|
||||
"dvclive>=2.12.0",
|
||||
]
|
||||
extra = [
|
||||
"hub-sdk>=0.0.2", # Ultralytics HUB
|
||||
"hub-sdk>=0.0.5", # Ultralytics HUB
|
||||
"ipython", # interactive notebook
|
||||
"albumentations>=1.0.3", # training augmentations
|
||||
"pycocotools>=2.0.7", # COCO mAP
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = "8.1.32"
|
||||
__version__ = "8.1.33"
|
||||
|
||||
from ultralytics.data.explorer.explorer import Explorer
|
||||
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
|
||||
|
@ -119,30 +119,27 @@ class Model(nn.Module):
|
||||
self.metrics = None # validation/training metrics
|
||||
self.session = None # HUB session
|
||||
self.task = task # task type
|
||||
self.model_name = model = str(model).strip() # strip spaces
|
||||
model = str(model).strip()
|
||||
|
||||
# Check if Ultralytics HUB model from https://hub.ultralytics.com
|
||||
if self.is_hub_model(model):
|
||||
# Fetch model from HUB
|
||||
checks.check_requirements("hub-sdk>=0.0.5")
|
||||
checks.check_requirements("hub-sdk>=0.0.6")
|
||||
self.session = self._get_hub_session(model)
|
||||
model = self.session.model_file
|
||||
|
||||
# Check if Triton Server model
|
||||
elif self.is_triton_model(model):
|
||||
self.model = model
|
||||
self.model_name = self.model = model
|
||||
self.task = task
|
||||
return
|
||||
|
||||
# Load or create new YOLO model
|
||||
model = checks.check_model_file_from_stem(model) # add suffix, i.e. yolov8n -> yolov8n.pt
|
||||
if Path(model).suffix in (".yaml", ".yml"):
|
||||
self._new(model, task=task, verbose=verbose)
|
||||
else:
|
||||
self._load(model, task=task)
|
||||
|
||||
self.model_name = model
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
|
||||
@ -190,8 +187,8 @@ class Model(nn.Module):
|
||||
return any(
|
||||
(
|
||||
model.startswith(f"{HUB_WEB_ROOT}/models/"), # i.e. https://hub.ultralytics.com/models/MODEL_ID
|
||||
[len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODELID
|
||||
len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"), # MODELID
|
||||
[len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODEL
|
||||
len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"), # MODEL
|
||||
)
|
||||
)
|
||||
|
||||
@ -215,6 +212,7 @@ class Model(nn.Module):
|
||||
# Below added to allow export from YAMLs
|
||||
self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args)
|
||||
self.model.task = self.task
|
||||
self.model_name = cfg
|
||||
|
||||
def _load(self, weights: str, task=None) -> None:
|
||||
"""
|
||||
@ -224,19 +222,23 @@ class Model(nn.Module):
|
||||
weights (str): model checkpoint to be loaded
|
||||
task (str | None): model task
|
||||
"""
|
||||
suffix = Path(weights).suffix
|
||||
if suffix == ".pt":
|
||||
if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):
|
||||
weights = checks.check_file(weights) # automatically download and return local filename
|
||||
weights = checks.check_model_file_from_stem(weights) # add suffix, i.e. yolov8n -> yolov8n.pt
|
||||
|
||||
if Path(weights).suffix == ".pt":
|
||||
self.model, self.ckpt = attempt_load_one_weight(weights)
|
||||
self.task = self.model.args["task"]
|
||||
self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
|
||||
self.ckpt_path = self.model.pt_path
|
||||
else:
|
||||
weights = checks.check_file(weights)
|
||||
weights = checks.check_file(weights) # runs in all cases, not redundant with above call
|
||||
self.model, self.ckpt = weights, None
|
||||
self.task = task or guess_model_task(weights)
|
||||
self.ckpt_path = weights
|
||||
self.overrides["model"] = weights
|
||||
self.overrides["task"] = self.task
|
||||
self.model_name = weights
|
||||
|
||||
def _check_is_pytorch_model(self) -> None:
|
||||
"""Raises TypeError is model is not a PyTorch model."""
|
||||
|
@ -23,7 +23,7 @@ def login(api_key: str = None, save=True) -> bool:
|
||||
Returns:
|
||||
(bool): True if authentication is successful, False otherwise.
|
||||
"""
|
||||
checks.check_requirements("hub-sdk>=0.0.2")
|
||||
checks.check_requirements("hub-sdk>=0.0.6")
|
||||
from hub_sdk import HUBClient
|
||||
|
||||
api_key_url = f"{HUB_WEB_ROOT}/settings?tab=api+keys" # set the redirect URL
|
||||
|
@ -603,7 +603,7 @@ class AutoBackend(nn.Module):
|
||||
from ultralytics.engine.exporter import export_formats
|
||||
|
||||
sf = list(export_formats().Suffix) # export suffixes
|
||||
if not is_url(p, check=False) and not isinstance(p, str):
|
||||
if not is_url(p) and not isinstance(p, str):
|
||||
check_suffix(p, sf) # checks
|
||||
name = Path(p).name
|
||||
types = [s in name for s in sf]
|
||||
|
@ -315,7 +315,7 @@ def check_font(font="Arial.ttf"):
|
||||
|
||||
# Download to USER_CONFIG_DIR if missing
|
||||
url = f"https://ultralytics.com/assets/{name}"
|
||||
if downloads.is_url(url):
|
||||
if downloads.is_url(url, check=True):
|
||||
downloads.safe_download(url=url, file=file)
|
||||
return file
|
||||
|
||||
@ -498,7 +498,7 @@ def check_file(file, suffix="", download=True, hard=True):
|
||||
raise FileNotFoundError(f"'{file}' does not exist")
|
||||
elif len(files) > 1 and hard:
|
||||
raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
|
||||
return files[0] if len(files) else [] # return file
|
||||
return files[0] if len(files) else [] if hard else file # return file
|
||||
|
||||
|
||||
def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
|
||||
|
@ -33,7 +33,7 @@ GITHUB_ASSETS_NAMES = (
|
||||
GITHUB_ASSETS_STEMS = [Path(k).stem for k in GITHUB_ASSETS_NAMES]
|
||||
|
||||
|
||||
def is_url(url, check=True):
|
||||
def is_url(url, check=False):
|
||||
"""
|
||||
Validates if the given string is a URL and optionally checks if the URL exists online.
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user