diff --git a/pyproject.toml b/pyproject.toml index 48fbb11e..d42c3805 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,7 +117,7 @@ logging = [ "dvclive>=2.12.0", ] extra = [ - "hub-sdk>=0.0.2", # Ultralytics HUB + "hub-sdk>=0.0.5", # Ultralytics HUB "ipython", # interactive notebook "albumentations>=1.0.3", # training augmentations "pycocotools>=2.0.7", # COCO mAP diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 5673168c..9a5b1ee3 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.1.32" +__version__ = "8.1.33" from ultralytics.data.explorer.explorer import Explorer from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py index ace045f7..ef5c93c0 100644 --- a/ultralytics/engine/model.py +++ b/ultralytics/engine/model.py @@ -119,30 +119,27 @@ class Model(nn.Module): self.metrics = None # validation/training metrics self.session = None # HUB session self.task = task # task type - self.model_name = model = str(model).strip() # strip spaces + model = str(model).strip() # Check if Ultralytics HUB model from https://hub.ultralytics.com if self.is_hub_model(model): # Fetch model from HUB - checks.check_requirements("hub-sdk>=0.0.5") + checks.check_requirements("hub-sdk>=0.0.6") self.session = self._get_hub_session(model) model = self.session.model_file # Check if Triton Server model elif self.is_triton_model(model): - self.model = model + self.model_name = self.model = model self.task = task return # Load or create new YOLO model - model = checks.check_model_file_from_stem(model) # add suffix, i.e. yolov8n -> yolov8n.pt if Path(model).suffix in (".yaml", ".yml"): self._new(model, task=task, verbose=verbose) else: self._load(model, task=task) - self.model_name = model - def __call__( self, source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None, @@ -190,8 +187,8 @@ class Model(nn.Module): return any( ( model.startswith(f"{HUB_WEB_ROOT}/models/"), # i.e. https://hub.ultralytics.com/models/MODEL_ID - [len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODELID - len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"), # MODELID + [len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODEL + len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"), # MODEL ) ) @@ -215,6 +212,7 @@ class Model(nn.Module): # Below added to allow export from YAMLs self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args) self.model.task = self.task + self.model_name = cfg def _load(self, weights: str, task=None) -> None: """ @@ -224,19 +222,23 @@ class Model(nn.Module): weights (str): model checkpoint to be loaded task (str | None): model task """ - suffix = Path(weights).suffix - if suffix == ".pt": + if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): + weights = checks.check_file(weights) # automatically download and return local filename + weights = checks.check_model_file_from_stem(weights) # add suffix, i.e. yolov8n -> yolov8n.pt + + if Path(weights).suffix == ".pt": self.model, self.ckpt = attempt_load_one_weight(weights) self.task = self.model.args["task"] self.overrides = self.model.args = self._reset_ckpt_args(self.model.args) self.ckpt_path = self.model.pt_path else: - weights = checks.check_file(weights) + weights = checks.check_file(weights) # runs in all cases, not redundant with above call self.model, self.ckpt = weights, None self.task = task or guess_model_task(weights) self.ckpt_path = weights self.overrides["model"] = weights self.overrides["task"] = self.task + self.model_name = weights def _check_is_pytorch_model(self) -> None: """Raises TypeError is model is not a PyTorch model.""" diff --git a/ultralytics/hub/__init__.py b/ultralytics/hub/__init__.py index b46bea40..4ea2fff8 100644 --- a/ultralytics/hub/__init__.py +++ b/ultralytics/hub/__init__.py @@ -23,7 +23,7 @@ def login(api_key: str = None, save=True) -> bool: Returns: (bool): True if authentication is successful, False otherwise. """ - checks.check_requirements("hub-sdk>=0.0.2") + checks.check_requirements("hub-sdk>=0.0.6") from hub_sdk import HUBClient api_key_url = f"{HUB_WEB_ROOT}/settings?tab=api+keys" # set the redirect URL diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py index 89057be5..abd255c9 100644 --- a/ultralytics/nn/autobackend.py +++ b/ultralytics/nn/autobackend.py @@ -603,7 +603,7 @@ class AutoBackend(nn.Module): from ultralytics.engine.exporter import export_formats sf = list(export_formats().Suffix) # export suffixes - if not is_url(p, check=False) and not isinstance(p, str): + if not is_url(p) and not isinstance(p, str): check_suffix(p, sf) # checks name = Path(p).name types = [s in name for s in sf] diff --git a/ultralytics/utils/checks.py b/ultralytics/utils/checks.py index e584a5b7..5d908851 100644 --- a/ultralytics/utils/checks.py +++ b/ultralytics/utils/checks.py @@ -315,7 +315,7 @@ def check_font(font="Arial.ttf"): # Download to USER_CONFIG_DIR if missing url = f"https://ultralytics.com/assets/{name}" - if downloads.is_url(url): + if downloads.is_url(url, check=True): downloads.safe_download(url=url, file=file) return file @@ -498,7 +498,7 @@ def check_file(file, suffix="", download=True, hard=True): raise FileNotFoundError(f"'{file}' does not exist") elif len(files) > 1 and hard: raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}") - return files[0] if len(files) else [] # return file + return files[0] if len(files) else [] if hard else file # return file def check_yaml(file, suffix=(".yaml", ".yml"), hard=True): diff --git a/ultralytics/utils/downloads.py b/ultralytics/utils/downloads.py index c9c33a85..6191ade2 100644 --- a/ultralytics/utils/downloads.py +++ b/ultralytics/utils/downloads.py @@ -33,7 +33,7 @@ GITHUB_ASSETS_NAMES = ( GITHUB_ASSETS_STEMS = [Path(k).stem for k in GITHUB_ASSETS_NAMES] -def is_url(url, check=True): +def is_url(url, check=False): """ Validates if the given string is a URL and optionally checks if the URL exists online.