ultralytics 8.1.30 add advanced HUB train arguments (#9110)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Kalen Michael 2024-03-20 23:53:00 +01:00 committed by GitHub
parent a62cdab53a
commit 8617fcf32d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 18 additions and 19 deletions

View File

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.1.29" __version__ = "8.1.30"
from ultralytics.data.explorer.explorer import Explorer from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld

View File

@ -124,7 +124,7 @@ class Model(nn.Module):
# Check if Ultralytics HUB model from https://hub.ultralytics.com # Check if Ultralytics HUB model from https://hub.ultralytics.com
if self.is_hub_model(model): if self.is_hub_model(model):
# Fetch model from HUB # Fetch model from HUB
checks.check_requirements("hub-sdk>0.0.2") checks.check_requirements("hub-sdk>=0.0.5")
self.session = self._get_hub_session(model) self.session = self._get_hub_session(model)
model = self.session.model_file model = self.session.model_file

View File

@ -170,10 +170,19 @@ class HUBTrainingSession:
return api_key, model_id, filename return api_key, model_id, filename
def _set_train_args(self, **kwargs): def _set_train_args(self):
"""Initializes training arguments and creates a model entry on the Ultralytics HUB.""" """
Initializes training arguments and creates a model entry on the Ultralytics HUB.
This method sets up training arguments based on the model's state and updates them with any additional
arguments provided. It handles different states of the model, such as whether it's resumable, pretrained,
or requires specific file setup.
Raises:
ValueError: If the model is already trained, if required dataset information is missing, or if there are
issues with the provided training arguments.
"""
if self.model.is_trained(): if self.model.is_trained():
# Model is already trained
raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀")) raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀"))
if self.model.is_resumable(): if self.model.is_resumable():
@ -182,26 +191,16 @@ class HUBTrainingSession:
self.model_file = self.model.get_weights_url("last") self.model_file = self.model.get_weights_url("last")
else: else:
# Model has no saved weights # Model has no saved weights
def get_train_args(config): self.train_args = self.model.data.get("train_args") # new response
"""Parses an identifier to extract API key, model ID, and filename if applicable."""
return {
"batch": config["batchSize"],
"epochs": config["epochs"],
"imgsz": config["imageSize"],
"patience": config["patience"],
"device": config["device"],
"cache": config["cache"],
"data": self.model.get_dataset_url(),
}
self.train_args = get_train_args(self.model.data.get("config"))
# Set the model file as either a *.pt or *.yaml file # Set the model file as either a *.pt or *.yaml file
self.model_file = ( self.model_file = (
self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture() self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture()
) )
if not self.train_args.get("data"): if "data" not in self.train_args:
raise ValueError("Dataset may still be processing. Please wait a minute and try again.") # RF fix # RF bug - datasets are sometimes not exported
raise ValueError("Dataset may still be processing. Please wait a minute and try again.")
self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u
self.model_id = self.model.id self.model_id = self.model.id