mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Add docstrings to new HUB functions (#7576)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
d762496989
commit
2f11ab5e6f
@ -68,7 +68,7 @@ class HUBTrainingSession:
|
|||||||
self.model = self.client.model() # load empty model
|
self.model = self.client.model() # load empty model
|
||||||
|
|
||||||
def load_model(self, model_id):
|
def load_model(self, model_id):
|
||||||
# Initialize model
|
"""Loads an existing model from Ultralytics HUB using the provided model identifier."""
|
||||||
self.model = self.client.model(model_id)
|
self.model = self.client.model(model_id)
|
||||||
if not self.model.data: # then model model does not exist
|
if not self.model.data: # then model model does not exist
|
||||||
raise ValueError(emojis(f"❌ The specified HUB model does not exist")) # TODO: improve error handling
|
raise ValueError(emojis(f"❌ The specified HUB model does not exist")) # TODO: improve error handling
|
||||||
@ -82,7 +82,7 @@ class HUBTrainingSession:
|
|||||||
LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
|
LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀")
|
||||||
|
|
||||||
def create_model(self, model_args):
|
def create_model(self, model_args):
|
||||||
# Initialize model
|
"""Initializes a HUB training session with the specified model identifier."""
|
||||||
payload = {
|
payload = {
|
||||||
"config": {
|
"config": {
|
||||||
"batchSize": model_args.get("batch", -1),
|
"batchSize": model_args.get("batch", -1),
|
||||||
@ -168,6 +168,7 @@ class HUBTrainingSession:
|
|||||||
return api_key, model_id, filename
|
return api_key, model_id, filename
|
||||||
|
|
||||||
def _set_train_args(self, **kwargs):
|
def _set_train_args(self, **kwargs):
|
||||||
|
"""Initializes training arguments and creates a model entry on the Ultralytics HUB."""
|
||||||
if self.model.is_trained():
|
if self.model.is_trained():
|
||||||
# Model is already trained
|
# Model is already trained
|
||||||
raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀"))
|
raise ValueError(emojis(f"Model is already trained and uploaded to {self.model_url} 🚀"))
|
||||||
@ -179,6 +180,7 @@ class HUBTrainingSession:
|
|||||||
else:
|
else:
|
||||||
# Model has no saved weights
|
# Model has no saved weights
|
||||||
def get_train_args(config):
|
def get_train_args(config):
|
||||||
|
"""Parses an identifier to extract API key, model ID, and filename if applicable."""
|
||||||
return {
|
return {
|
||||||
"batch": config["batchSize"],
|
"batch": config["batchSize"],
|
||||||
"epochs": config["epochs"],
|
"epochs": config["epochs"],
|
||||||
@ -213,6 +215,7 @@ class HUBTrainingSession:
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
def retry_request():
|
def retry_request():
|
||||||
|
"""Attempts to call `request_func` with retries, timeout, and optional threading."""
|
||||||
t0 = time.time() # Record the start time for the timeout
|
t0 = time.time() # Record the start time for the timeout
|
||||||
for i in range(retry + 1):
|
for i in range(retry + 1):
|
||||||
if (time.time() - t0) > timeout:
|
if (time.time() - t0) > timeout:
|
||||||
@ -254,7 +257,7 @@ class HUBTrainingSession:
|
|||||||
return retry_request()
|
return retry_request()
|
||||||
|
|
||||||
def _should_retry(self, status_code):
|
def _should_retry(self, status_code):
|
||||||
# Status codes that trigger retries
|
"""Determines if a request should be retried based on the HTTP status code."""
|
||||||
retry_codes = {
|
retry_codes = {
|
||||||
HTTPStatus.REQUEST_TIMEOUT,
|
HTTPStatus.REQUEST_TIMEOUT,
|
||||||
HTTPStatus.BAD_GATEWAY,
|
HTTPStatus.BAD_GATEWAY,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user