Add docstrings to new HUB functions (#7576)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-01-14 20:23:02 +01:00 committed by GitHub
parent d762496989
commit 2f11ab5e6f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -68,7 +68,7 @@ class HUBTrainingSession:
self.model = self.client.model() # load empty model self.model = self.client.model() # load empty model
def load_model(self, model_id): def load_model(self, model_id):
# Initialize model """Loads an existing model from Ultralytics HUB using the provided model identifier."""
self.model = self.client.model(model_id) self.model = self.client.model(model_id)
if not self.model.data: # then model model does not exist if not self.model.data: # then model model does not exist
raise ValueError(emojis(f"❌ The specified HUB model does not exist")) # TODO: improve error handling raise ValueError(emojis(f"❌ The specified HUB model does not exist")) # TODO: improve error handling
@ -82,7 +82,7 @@ class HUBTrainingSession:
LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀") LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
def create_model(self, model_args): def create_model(self, model_args):
# Initialize model """Initializes a HUB training session with the specified model identifier."""
payload = { payload = {
"config": { "config": {
"batchSize": model_args.get("batch", -1), "batchSize": model_args.get("batch", -1),
@ -168,6 +168,7 @@ class HUBTrainingSession:
return api_key, model_id, filename return api_key, model_id, filename
def _set_train_args(self, **kwargs): def _set_train_args(self, **kwargs):
"""Initializes training arguments and creates a model entry on the Ultralytics HUB."""
if self.model.is_trained(): if self.model.is_trained():
# Model is already trained # Model is already trained
raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀")) raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀"))
@ -179,6 +180,7 @@ class HUBTrainingSession:
else: else:
# Model has no saved weights # Model has no saved weights
def get_train_args(config): def get_train_args(config):
"""Parses an identifier to extract API key, model ID, and filename if applicable."""
return { return {
"batch": config["batchSize"], "batch": config["batchSize"],
"epochs": config["epochs"], "epochs": config["epochs"],
@ -213,6 +215,7 @@ class HUBTrainingSession:
**kwargs, **kwargs,
): ):
def retry_request(): def retry_request():
"""Attempts to call `request_func` with retries, timeout, and optional threading."""
t0 = time.time() # Record the start time for the timeout t0 = time.time() # Record the start time for the timeout
for i in range(retry + 1): for i in range(retry + 1):
if (time.time() - t0) > timeout: if (time.time() - t0) > timeout:
@ -254,7 +257,7 @@ class HUBTrainingSession:
return retry_request() return retry_request()
def _should_retry(self, status_code): def _should_retry(self, status_code):
# Status codes that trigger retries """Determines if a request should be retried based on the HTTP status code."""
retry_codes = { retry_codes = {
HTTPStatus.REQUEST_TIMEOUT, HTTPStatus.REQUEST_TIMEOUT,
HTTPStatus.BAD_GATEWAY, HTTPStatus.BAD_GATEWAY,