mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
Feature: Create HUB Models from CLI or Python Script (#7316)
Co-authored-by: Hassaan Farooq <103611273+hassaanfarooq01@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
a92adf8231
commit
b54055a2c7
@ -78,6 +78,7 @@ dependencies = [
|
|||||||
"thop>=0.1.1", # FLOPs computation
|
"thop>=0.1.1", # FLOPs computation
|
||||||
"pandas>=1.1.4",
|
"pandas>=1.1.4",
|
||||||
"seaborn>=0.11.0", # plotting
|
"seaborn>=0.11.0", # plotting
|
||||||
|
"hub-sdk>=0.0.2", # Ultralytics HUB
|
||||||
]
|
]
|
||||||
|
|
||||||
# Optional dependencies ------------------------------------------------------------------------------------------------
|
# Optional dependencies ------------------------------------------------------------------------------------------------
|
||||||
|
@ -5,10 +5,11 @@ import sys
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
from hub_sdk.config import HUB_WEB_ROOT
|
||||||
|
|
||||||
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
|
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
|
||||||
from ultralytics.hub.utils import HUB_WEB_ROOT
|
|
||||||
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
|
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
|
||||||
from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, checks, emojis, yaml_load
|
from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, SETTINGS, callbacks, checks, emojis, yaml_load
|
||||||
|
|
||||||
|
|
||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
@ -76,8 +77,8 @@ class Model(nn.Module):
|
|||||||
|
|
||||||
# Check if Ultralytics HUB model from https://hub.ultralytics.com
|
# Check if Ultralytics HUB model from https://hub.ultralytics.com
|
||||||
if self.is_hub_model(model):
|
if self.is_hub_model(model):
|
||||||
from ultralytics.hub.session import HUBTrainingSession
|
# Fetch model from HUB
|
||||||
self.session = HUBTrainingSession(model)
|
self.session = self._get_hub_session(model)
|
||||||
model = self.session.model_file
|
model = self.session.model_file
|
||||||
|
|
||||||
# Check if Triton Server model
|
# Check if Triton Server model
|
||||||
@ -93,10 +94,20 @@ class Model(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self._load(model, task)
|
self._load(model, task)
|
||||||
|
|
||||||
|
self.model_name = model
|
||||||
|
|
||||||
def __call__(self, source=None, stream=False, **kwargs):
|
def __call__(self, source=None, stream=False, **kwargs):
|
||||||
"""Calls the predict() method with given arguments to perform object detection."""
|
"""Calls the predict() method with given arguments to perform object detection."""
|
||||||
return self.predict(source, stream, **kwargs)
|
return self.predict(source, stream, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
def _get_hub_session(model: str):
    """
    Build a HUB training session for the given model identifier.

    Args:
        model (str): Identifier passed unchanged to HUBTrainingSession.

    Returns:
        The new HUBTrainingSession when its client reports itself as authenticated, otherwise None.
    """
    from ultralytics.hub.session import HUBTrainingSession  # imported here rather than at module level

    hub_session = HUBTrainingSession(model)
    if not hub_session.client.authenticated:
        return None
    return hub_session
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_triton_model(model):
|
def is_triton_model(model):
|
||||||
"""Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
|
"""Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
|
||||||
@ -336,10 +347,11 @@ class Model(nn.Module):
|
|||||||
**kwargs (Any): Any number of arguments representing the training configuration.
|
**kwargs (Any): Any number of arguments representing the training configuration.
|
||||||
"""
|
"""
|
||||||
self._check_is_pytorch_model()
|
self._check_is_pytorch_model()
|
||||||
if self.session: # Ultralytics HUB session
|
if hasattr(self.session, 'model') and self.session.model.id: # Ultralytics HUB session with loaded model
|
||||||
if any(kwargs):
|
if any(kwargs):
|
||||||
LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
|
LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
|
||||||
kwargs = self.session.train_args
|
kwargs = self.session.train_args # overwrite kwargs
|
||||||
|
|
||||||
checks.check_pip_update_available()
|
checks.check_pip_update_available()
|
||||||
|
|
||||||
overrides = yaml_load(checks.check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides
|
overrides = yaml_load(checks.check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides
|
||||||
@ -352,6 +364,20 @@ class Model(nn.Module):
|
|||||||
if not args.get('resume'): # manually set model only if not resuming
|
if not args.get('resume'): # manually set model only if not resuming
|
||||||
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
|
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
|
||||||
self.model = self.trainer.model
|
self.model = self.trainer.model
|
||||||
|
|
||||||
|
if SETTINGS['hub'] is True and not self.session:
|
||||||
|
# Create a model in HUB
|
||||||
|
try:
|
||||||
|
self.session = self._get_hub_session(self.model_name)
|
||||||
|
if self.session:
|
||||||
|
self.session.create_model(args)
|
||||||
|
# Check model was created
|
||||||
|
if not getattr(self.session.model, 'id', None):
|
||||||
|
self.session = None
|
||||||
|
except PermissionError:
|
||||||
|
# Ignore permission error
|
||||||
|
pass
|
||||||
|
|
||||||
self.trainer.hub_session = self.session # attach optional HUB session
|
self.trainer.hub_session = self.session # attach optional HUB session
|
||||||
self.trainer.train()
|
self.trainer.train()
|
||||||
# Update model and cfg after training
|
# Update model and cfg after training
|
||||||
|
@ -1,28 +1,49 @@
|
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
from hub_sdk import HUB_API_ROOT, HUB_WEB_ROOT, HUBClient
|
||||||
|
|
||||||
from ultralytics.data.utils import HUBDatasetStats
|
from ultralytics.data.utils import HUBDatasetStats
|
||||||
from ultralytics.hub.auth import Auth
|
from ultralytics.hub.auth import Auth
|
||||||
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX
|
from ultralytics.hub.utils import PREFIX
|
||||||
from ultralytics.utils import LOGGER, SETTINGS
|
from ultralytics.utils import LOGGER, SETTINGS
|
||||||
|
|
||||||
|
|
||||||
def login(api_key=''):
|
def login(api_key: str = None, save=True) -> bool:
|
||||||
"""
|
"""
|
||||||
Log in to the Ultralytics HUB API using the provided API key.
|
Log in to the Ultralytics HUB API using the provided API key.
|
||||||
|
|
||||||
|
The session is not stored; a new session is created when needed using the saved SETTINGS or the HUB_API_KEY environment variable if successfully authenticated.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
|
api_key (str, optional): The API key to use for authentication. If not provided, it will be retrieved from SETTINGS or HUB_API_KEY environment variable.
|
||||||
|
save (bool, optional): Whether to save the API key to SETTINGS if authentication is successful.
|
||||||
Example:
|
Returns:
|
||||||
```python
|
bool: True if authentication is successful, False otherwise.
|
||||||
from ultralytics import hub
|
|
||||||
|
|
||||||
hub.login('API_KEY')
|
|
||||||
```
|
|
||||||
"""
|
"""
|
||||||
Auth(api_key, verbose=True)
|
api_key_url = f'{HUB_WEB_ROOT}/settings?tab=api+keys' # Set the redirect URL
|
||||||
|
saved_key = SETTINGS.get('api_key')
|
||||||
|
active_key = api_key or saved_key
|
||||||
|
credentials = {'api_key': active_key} if active_key and active_key != '' else None # Set credentials
|
||||||
|
|
||||||
|
client = HUBClient(credentials) # initialize HUBClient
|
||||||
|
|
||||||
|
if client.authenticated:
|
||||||
|
# Successfully authenticated with HUB
|
||||||
|
|
||||||
|
if save and client.api_key != saved_key:
|
||||||
|
SETTINGS.update({'api_key': client.api_key}) # update settings with valid API key
|
||||||
|
|
||||||
|
# Set message based on whether key was provided or retrieved from settings
|
||||||
|
log_message = ('New authentication successful ✅'
|
||||||
|
if client.api_key == api_key or not credentials else 'Authenticated ✅')
|
||||||
|
LOGGER.info(f'{PREFIX}{log_message}')
|
||||||
|
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
# Failed to authenticate with HUB
|
||||||
|
LOGGER.info(f'{PREFIX}Retrieve API key from {api_key_url}')
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def logout():
|
def logout():
|
||||||
@ -43,7 +64,7 @@ def logout():
|
|||||||
|
|
||||||
def reset_model(model_id=''):
|
def reset_model(model_id=''):
|
||||||
"""Reset a trained model to an untrained state."""
|
"""Reset a trained model to an untrained state."""
|
||||||
r = requests.post(f'{HUB_API_ROOT}/model-reset', json={'apiKey': Auth().api_key, 'modelId': model_id})
|
r = requests.post(f'{HUB_API_ROOT}/model-reset', json={'modelId': model_id}, headers={'x-api-key': Auth().api_key})
|
||||||
if r.status_code == 200:
|
if r.status_code == 200:
|
||||||
LOGGER.info(f'{PREFIX}Model reset successfully')
|
LOGGER.info(f'{PREFIX}Model reset successfully')
|
||||||
return
|
return
|
||||||
@ -73,7 +94,8 @@ def get_export(model_id='', format='torchscript'):
|
|||||||
json={
|
json={
|
||||||
'apiKey': Auth().api_key,
|
'apiKey': Auth().api_key,
|
||||||
'modelId': model_id,
|
'modelId': model_id,
|
||||||
'format': format})
|
'format': format},
|
||||||
|
headers={'x-api-key': Auth().api_key})
|
||||||
assert r.status_code == 200, f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}'
|
assert r.status_code == 200, f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}'
|
||||||
return r.json()
|
return r.json()
|
||||||
|
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
from hub_sdk import HUB_API_ROOT, HUB_WEB_ROOT
|
||||||
|
|
||||||
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, request_with_credentials
|
from ultralytics.hub.utils import PREFIX, request_with_credentials
|
||||||
from ultralytics.utils import LOGGER, SETTINGS, emojis, is_colab
|
from ultralytics.utils import LOGGER, SETTINGS, emojis, is_colab
|
||||||
|
|
||||||
API_KEY_URL = f'{HUB_WEB_ROOT}/settings?tab=api+keys'
|
API_KEY_URL = f'{HUB_WEB_ROOT}/settings?tab=api+keys'
|
||||||
|
@ -1,17 +1,18 @@
|
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
import signal
|
import threading
|
||||||
import sys
|
import time
|
||||||
|
from http import HTTPStatus
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from time import sleep
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
from hub_sdk import HUB_WEB_ROOT, HUBClient
|
||||||
|
|
||||||
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, smart_request
|
from ultralytics.hub.utils import HELP_MSG, PREFIX, TQDM
|
||||||
from ultralytics.utils import LOGGER, __version__, checks, emojis, is_colab, threaded
|
from ultralytics.utils import LOGGER, SETTINGS, __version__, checks, emojis, is_colab
|
||||||
from ultralytics.utils.errors import HUBModelError
|
from ultralytics.utils.errors import HUBModelError
|
||||||
|
|
||||||
AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
|
AGENT_NAME = (f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local')
|
||||||
|
|
||||||
|
|
||||||
class HUBTrainingSession:
|
class HUBTrainingSession:
|
||||||
@ -34,7 +35,7 @@ class HUBTrainingSession:
|
|||||||
alive (bool): Indicates if the heartbeat loop is active.
|
alive (bool): Indicates if the heartbeat loop is active.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, url):
|
def __init__(self, identifier):
|
||||||
"""
|
"""
|
||||||
Initialize the HUBTrainingSession with the provided model identifier.
|
Initialize the HUBTrainingSession with the provided model identifier.
|
||||||
|
|
||||||
@ -46,98 +47,251 @@ class HUBTrainingSession:
|
|||||||
ValueError: If the provided model identifier is invalid.
|
ValueError: If the provided model identifier is invalid.
|
||||||
ConnectionError: If connecting with global API key is not supported.
|
ConnectionError: If connecting with global API key is not supported.
|
||||||
"""
|
"""
|
||||||
|
self.rate_limits = {
|
||||||
from ultralytics.hub.auth import Auth
|
'metrics': 3.0,
|
||||||
|
'ckpt': 900.0,
|
||||||
|
'heartbeat': 300.0, } # rate limits (seconds)
|
||||||
|
self.metrics_queue = {} # holds metrics for each epoch until upload
|
||||||
|
self.timers = {} # holds timers in ultralytics/utils/callbacks/hub.py
|
||||||
|
|
||||||
# Parse input
|
# Parse input
|
||||||
if url.startswith(f'{HUB_WEB_ROOT}/models/'):
|
api_key, model_id, self.filename = self._parse_identifier(identifier)
|
||||||
url = url.split(f'{HUB_WEB_ROOT}/models/')[-1]
|
|
||||||
if [len(x) for x in url.split('_')] == [42, 20]:
|
|
||||||
key, model_id = url.split('_')
|
|
||||||
elif len(url) == 20:
|
|
||||||
key, model_id = '', url
|
|
||||||
else:
|
|
||||||
raise HUBModelError(f"model='{url}' not found. Check format is correct, i.e. "
|
|
||||||
f"model='{HUB_WEB_ROOT}/models/MODEL_ID' and try again.")
|
|
||||||
|
|
||||||
# Authorize
|
# Get credentials
|
||||||
auth = Auth(key)
|
active_key = api_key or SETTINGS.get('api_key')
|
||||||
self.agent_id = None # identifies which instance is communicating with server
|
credentials = {'api_key': active_key} if active_key else None # set credentials
|
||||||
self.model_id = model_id
|
|
||||||
self.model_url = f'{HUB_WEB_ROOT}/models/{model_id}'
|
# Initialize client
|
||||||
self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
|
self.client = HUBClient(credentials)
|
||||||
self.auth_header = auth.get_auth_header()
|
|
||||||
self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0} # rate limits (seconds)
|
if model_id:
|
||||||
self.timers = {} # rate limit timers (seconds)
|
self.load_model(model_id) # load existing model
|
||||||
self.metrics_queue = {} # metrics queue
|
else:
|
||||||
self.model = self._get_model()
|
self.model = self.client.model() # load empty model
|
||||||
self.alive = True
|
|
||||||
self._start_heartbeat() # start heartbeats
|
def load_model(self, model_id):
|
||||||
self._register_signal_handlers()
|
# Initialize model
|
||||||
|
self.model = self.client.model(model_id)
|
||||||
|
self.model_url = f'{HUB_WEB_ROOT}/models/{self.model.id}'
|
||||||
|
|
||||||
|
self._set_train_args()
|
||||||
|
|
||||||
|
# Start heartbeats for HUB to monitor agent
|
||||||
|
self.model.start_heartbeat(self.rate_limits['heartbeat'])
|
||||||
LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')
|
LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')
|
||||||
|
|
||||||
def _register_signal_handlers(self):
|
def create_model(self, model_args):
|
||||||
"""Register signal handlers for SIGTERM and SIGINT signals to gracefully handle termination."""
|
# Initialize model
|
||||||
signal.signal(signal.SIGTERM, self._handle_signal)
|
payload = {
|
||||||
signal.signal(signal.SIGINT, self._handle_signal)
|
'config': {
|
||||||
|
'batchSize': model_args.get('batch', -1),
|
||||||
|
'epochs': model_args.get('epochs', 300),
|
||||||
|
'imageSize': model_args.get('imgsz', 640),
|
||||||
|
'patience': model_args.get('patience', 100),
|
||||||
|
'device': model_args.get('device', ''),
|
||||||
|
'cache': model_args.get('cache', 'ram'), },
|
||||||
|
'dataset': {
|
||||||
|
'name': model_args.get('data')},
|
||||||
|
'lineage': {
|
||||||
|
'architecture': {
|
||||||
|
'name': self.filename.replace('.pt', '').replace('.yaml', ''), },
|
||||||
|
'parent': {}, },
|
||||||
|
'meta': {
|
||||||
|
'name': self.filename}, }
|
||||||
|
|
||||||
def _handle_signal(self, signum, frame):
|
if self.filename.endswith('.pt'):
|
||||||
|
payload['lineage']['parent']['name'] = self.filename
|
||||||
|
|
||||||
|
self.model.create_model(payload)
|
||||||
|
|
||||||
|
# Model could not be created
|
||||||
|
# TODO: improve error handling
|
||||||
|
if not self.model.id:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.model_url = f'{HUB_WEB_ROOT}/models/{self.model.id}'
|
||||||
|
|
||||||
|
# Start heartbeats for HUB to monitor agent
|
||||||
|
self.model.start_heartbeat(self.rate_limits['heartbeat'])
|
||||||
|
|
||||||
|
LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')
|
||||||
|
|
||||||
|
def _parse_identifier(self, identifier):
|
||||||
"""
|
"""
|
||||||
Handle kill signals and prevent heartbeats from being sent on Colab after termination.
|
Parses the given identifier to determine the type of identifier and extract relevant components.
|
||||||
|
|
||||||
This method does not use frame, it is included as it is passed by signal.
|
The method supports different identifier formats:
|
||||||
|
- A HUB URL, which starts with HUB_WEB_ROOT followed by '/models/'
|
||||||
|
- An identifier containing an API key and a model ID separated by an underscore
|
||||||
|
- An identifier that is solely a model ID of a fixed length
|
||||||
|
- A local filename that ends with '.pt' or '.yaml'
|
||||||
|
|
||||||
|
Args:
|
||||||
|
identifier (str): The identifier string to be parsed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(tuple): A tuple containing the API key, model ID, and filename as applicable.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HUBModelError: If the identifier format is not recognized.
|
||||||
"""
|
"""
|
||||||
if self.alive is True:
|
|
||||||
LOGGER.info(f'{PREFIX}Kill signal received! ❌')
|
|
||||||
self._stop_heartbeat()
|
|
||||||
sys.exit(signum)
|
|
||||||
|
|
||||||
def _stop_heartbeat(self):
|
# Initialize variables
|
||||||
"""Terminate the heartbeat loop."""
|
api_key, model_id, filename = None, None, None
|
||||||
self.alive = False
|
|
||||||
|
# Check if identifier is a HUB URL
|
||||||
|
if identifier.startswith(f'{HUB_WEB_ROOT}/models/'):
|
||||||
|
# Extract the model_id after the HUB_WEB_ROOT URL
|
||||||
|
model_id = identifier.split(f'{HUB_WEB_ROOT}/models/')[-1]
|
||||||
|
else:
|
||||||
|
# Split the identifier based on underscores only if it's not a HUB URL
|
||||||
|
parts = identifier.split('_')
|
||||||
|
|
||||||
|
# Check if identifier is in the format of API key and model ID
|
||||||
|
if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
|
||||||
|
api_key, model_id = parts
|
||||||
|
# Check if identifier is a single model ID
|
||||||
|
elif len(parts) == 1 and len(parts[0]) == 20:
|
||||||
|
model_id = parts[0]
|
||||||
|
# Check if identifier is a local filename
|
||||||
|
elif identifier.endswith('.pt') or identifier.endswith('.yaml'):
|
||||||
|
filename = identifier
|
||||||
|
else:
|
||||||
|
raise HUBModelError(
|
||||||
|
f"model='{identifier}' could not be parsed. Check format is correct. "
|
||||||
|
f'Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file.')
|
||||||
|
|
||||||
|
return api_key, model_id, filename
|
||||||
|
|
||||||
|
def _set_train_args(self, **kwargs):
    """
    Populate ``train_args``, ``model_file`` and ``model_id`` from the HUB model stored in ``self.model``.

    A resumable model is set up to resume from its 'last' weights. Any other model takes its training arguments
    from the 'config' entry of ``self.model.data`` and uses the 'parent' weights when the model reports itself as
    pretrained, or its architecture otherwise.

    Args:
        **kwargs: Accepted but not used by this method.

    Raises:
        ValueError: If the model reports itself as already trained, or if no dataset URL is available.
    """
    hub_model = self.model

    if hub_model.is_trained():
        # Model is already trained
        raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀'))

    if not hub_model.is_resumable():
        # Model has no saved weights: translate the stored HUB config keys into trainer argument names
        config = hub_model.data.get('config')
        self.train_args = {
            'batch': config['batchSize'],
            'epochs': config['epochs'],
            'imgsz': config['imageSize'],
            'patience': config['patience'],
            'device': config['device'],
            'cache': config['cache'],
            'data': hub_model.get_dataset_url(), }
        # Set the model file as either a *.pt or *.yaml file
        if hub_model.is_pretrained():
            self.model_file = hub_model.get_weights_url('parent')
        else:
            self.model_file = hub_model.get_architecture()
    else:
        # Model has saved weights: resume from them
        self.train_args = {'data': hub_model.get_dataset_url(), 'resume': True}
        self.model_file = hub_model.get_weights_url('last')

    dataset = self.train_args.get('data')
    if not dataset:
        raise ValueError('Dataset may still be processing. Please wait a minute and try again.')  # RF fix

    self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False)  # YOLOv5->YOLOv5u
    self.model_id = hub_model.id
|
||||||
|
|
||||||
|
def request_queue(
    self,
    request_func,
    retry=3,
    timeout=30,
    thread=True,
    verbose=True,
    progress_total=None,
    *args,
    **kwargs,
):
    """
    Run ``request_func`` with retries, exponential backoff and an overall time budget.

    Args:
        request_func (callable): Called as ``request_func(*args, **kwargs)``. Expected to return a response object
            exposing ``status_code``, or None when no response was obtained.
        retry (int): Number of additional attempts allowed after the first one.
        timeout (int | float): Seconds after which no further attempt is started (checked before each attempt).
        thread (bool): If True, run in a daemon thread and return None immediately; otherwise run inline.
        verbose (bool): If True, log a diagnostic warning after the first failed attempt.
        progress_total (int, optional): When truthy, passed together with the response to
            ``self._show_upload_progress``.
        *args: Positional arguments forwarded to ``request_func``.
        **kwargs: Keyword arguments forwarded to ``request_func``.

    Returns:
        When ``thread`` is False, the last response obtained (None if there was none). When ``thread`` is True,
        None, because the result of the background thread is not collected.
    """

    def retry_request():
        """Attempt the request until it succeeds, fails in a non-retryable way, or the budget is used up."""
        response = None  # stays None if the time budget is exhausted before the first attempt
        t0 = time.time()  # Record the start time for the timeout
        for i in range(retry + 1):
            if (time.time() - t0) > timeout:
                LOGGER.warning(f'{PREFIX}Timeout for request reached. {HELP_MSG}')
                break  # Timeout reached, exit loop

            response = request_func(*args, **kwargs)
            last_attempt = i == retry  # no point in backing off when nothing follows

            if response is None:
                LOGGER.warning(f'{PREFIX}Received no response from the request. {HELP_MSG}')
                if not last_attempt:
                    time.sleep(2 ** i)  # Exponential backoff before retrying
                continue  # Skip further processing and retry

            # Only reached with a real response, so the progress helper never receives None
            if progress_total:
                self._show_upload_progress(progress_total, response)

            if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
                return response  # Success, no need to retry

            if i == 0 and verbose:
                # Initial attempt, check status code and provide messages
                message = self._get_failure_message(response, retry, timeout)
                LOGGER.warning(f'{PREFIX}{message} {HELP_MSG} ({response.status_code})')

            if not self._should_retry(response.status_code):
                LOGGER.warning(f'{PREFIX}Request failed. {HELP_MSG} ({response.status_code})')
                break  # Not an error that should be retried, exit loop

            if not last_attempt:
                time.sleep(2 ** i)  # Exponential backoff for retries

        return response

    if thread:
        # Start a new thread to run the retry_request function
        threading.Thread(target=retry_request, daemon=True).start()
    else:
        # If running in the main thread, call retry_request directly
        return retry_request()
|
||||||
|
|
||||||
|
def _should_retry(self, status_code):
    """
    Tell whether a failed request with the given HTTP status code should be attempted again.

    Args:
        status_code (int): HTTP status code of the response.

    Returns:
        bool: True for 408 Request Timeout, 502 Bad Gateway and 504 Gateway Timeout, False for anything else.
    """
    # Status codes that trigger retries
    return status_code in {
        HTTPStatus.REQUEST_TIMEOUT,
        HTTPStatus.BAD_GATEWAY,
        HTTPStatus.GATEWAY_TIMEOUT, }
|
||||||
|
|
||||||
|
def _get_failure_message(self, response: 'requests.Response', retry: int, timeout: int):
    """
    Generate a human-readable message describing why a request failed.

    Args:
        response: The HTTP response object.
        retry: The number of retry attempts allowed.
        timeout: The maximum timeout duration.

    Returns:
        str: A retry notice for retryable status codes (empty when ``retry`` is falsy), a rate-limit notice for
            429 responses, otherwise the 'message' field of the JSON body or a fallback text when the body cannot
            be read as a JSON object.
    """
    if self._should_retry(response.status_code):
        return f'Retrying {retry}x for {timeout}s.' if retry else ''

    if response.status_code == HTTPStatus.TOO_MANY_REQUESTS:  # rate limit
        headers = response.headers
        # Headers may be absent; a diagnostic message must not fail with KeyError
        remaining = headers.get('X-RateLimit-Remaining', '?')
        limit = headers.get('X-RateLimit-Limit', '?')
        retry_after = headers.get('Retry-After', '?')
        return f'Rate limit reached ({remaining}/{limit}). Please retry after {retry_after}s.'

    try:
        return response.json().get('message', 'No JSON message.')
    except (AttributeError, ValueError):
        # ValueError: body is not valid JSON; AttributeError: JSON decoded to something without .get()
        return 'Unable to read JSON.'
|
||||||
|
|
||||||
def upload_metrics(self):
|
def upload_metrics(self):
|
||||||
"""Upload model metrics to Ultralytics HUB."""
|
"""Upload model metrics to Ultralytics HUB."""
|
||||||
payload = {'metrics': self.metrics_queue.copy(), 'type': 'metrics'}
|
return self.request_queue(self.model.upload_metrics, metrics=self.metrics_queue.copy(), thread=True)
|
||||||
smart_request('post', self.api_url, json=payload, headers=self.auth_header, code=2)
|
|
||||||
|
|
||||||
def _get_model(self):
|
def upload_model(
|
||||||
"""Fetch and return model data from Ultralytics HUB."""
|
self,
|
||||||
api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'
|
epoch: int,
|
||||||
|
weights: str,
|
||||||
try:
|
is_best: bool = False,
|
||||||
response = smart_request('get', api_url, headers=self.auth_header, thread=False, code=0)
|
map: float = 0.0,
|
||||||
data = response.json().get('data', None)
|
final: bool = False,
|
||||||
|
) -> None:
|
||||||
if data.get('status', None) == 'trained':
|
|
||||||
raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀'))
|
|
||||||
|
|
||||||
if not data.get('data', None):
|
|
||||||
raise ValueError('Dataset may still be processing. Please wait a minute and try again.') # RF fix
|
|
||||||
self.model_id = data['id']
|
|
||||||
|
|
||||||
if data['status'] == 'new': # new model to start training
|
|
||||||
self.train_args = {
|
|
||||||
'batch': data['batch_size'], # note HUB argument is slightly different
|
|
||||||
'epochs': data['epochs'],
|
|
||||||
'imgsz': data['imgsz'],
|
|
||||||
'patience': data['patience'],
|
|
||||||
'device': data['device'],
|
|
||||||
'cache': data['cache'],
|
|
||||||
'data': data['data']}
|
|
||||||
self.model_file = data.get('cfg') or data.get('weights') # cfg for pretrained=False
|
|
||||||
self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u
|
|
||||||
elif data['status'] == 'training': # existing model to resume training
|
|
||||||
self.train_args = {'data': data['data'], 'resume': True}
|
|
||||||
self.model_file = data['resume']
|
|
||||||
|
|
||||||
return data
|
|
||||||
except requests.exceptions.ConnectionError as e:
|
|
||||||
raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e
|
|
||||||
except Exception:
|
|
||||||
raise
|
|
||||||
|
|
||||||
def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
|
|
||||||
"""
|
"""
|
||||||
Upload a model checkpoint to Ultralytics HUB.
|
Upload a model checkpoint to Ultralytics HUB.
|
||||||
|
|
||||||
@ -149,43 +303,33 @@ class HUBTrainingSession:
|
|||||||
final (bool): Indicates if the model is the final model after training.
|
final (bool): Indicates if the model is the final model after training.
|
||||||
"""
|
"""
|
||||||
if Path(weights).is_file():
|
if Path(weights).is_file():
|
||||||
with open(weights, 'rb') as f:
|
progress_total = (Path(weights).stat().st_size if final else None) # Only show progress if final
|
||||||
file = f.read()
|
self.request_queue(
|
||||||
|
self.model.upload_model,
|
||||||
|
epoch=epoch,
|
||||||
|
weights=weights,
|
||||||
|
is_best=is_best,
|
||||||
|
map=map,
|
||||||
|
final=final,
|
||||||
|
retry=10,
|
||||||
|
timeout=3600,
|
||||||
|
thread=not final,
|
||||||
|
progress_total=progress_total,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.')
|
LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.')
|
||||||
file = None
|
|
||||||
url = f'{self.api_url}/upload'
|
|
||||||
# url = 'http://httpbin.org/post' # for debug
|
|
||||||
data = {'epoch': epoch}
|
|
||||||
if final:
|
|
||||||
data.update({'type': 'final', 'map': map})
|
|
||||||
filesize = Path(weights).stat().st_size
|
|
||||||
smart_request('post',
|
|
||||||
url,
|
|
||||||
data=data,
|
|
||||||
files={'best.pt': file},
|
|
||||||
headers=self.auth_header,
|
|
||||||
retry=10,
|
|
||||||
timeout=3600,
|
|
||||||
thread=False,
|
|
||||||
progress=filesize,
|
|
||||||
code=4)
|
|
||||||
else:
|
|
||||||
data.update({'type': 'epoch', 'isBest': bool(is_best)})
|
|
||||||
smart_request('post', url, data=data, files={'last.pt': file}, headers=self.auth_header, code=3)
|
|
||||||
|
|
||||||
@threaded
|
def _show_upload_progress(self, content_length: int, response: requests.Response) -> None:
|
||||||
def _start_heartbeat(self):
|
"""
|
||||||
"""Begin a threaded heartbeat loop to report the agent's status to Ultralytics HUB."""
|
Display a progress bar to track the upload progress of a file download.
|
||||||
while self.alive:
|
|
||||||
r = smart_request('post',
|
Args:
|
||||||
f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
|
content_length (int): The total size of the content to be downloaded in bytes.
|
||||||
json={
|
response (requests.Response): The response object from the file download request.
|
||||||
'agent': AGENT_NAME,
|
|
||||||
'agentId': self.agent_id},
|
Returns:
|
||||||
headers=self.auth_header,
|
(None)
|
||||||
retry=0,
|
"""
|
||||||
code=5,
|
with TQDM(total=content_length, unit='B', unit_scale=True, unit_divisor=1024) as pbar:
|
||||||
thread=False) # already in a thread
|
for data in response.iter_content(chunk_size=1024):
|
||||||
self.agent_id = r.json().get('data', {}).get('agentId', None)
|
pbar.update(len(data))
|
||||||
sleep(self.rate_limits['heartbeat'])
|
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
import os
|
|
||||||
import platform
|
import platform
|
||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
@ -16,8 +15,6 @@ from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES
|
|||||||
|
|
||||||
PREFIX = colorstr('Ultralytics HUB: ')
|
PREFIX = colorstr('Ultralytics HUB: ')
|
||||||
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
|
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
|
||||||
HUB_API_ROOT = os.environ.get('ULTRALYTICS_HUB_API', 'https://api.ultralytics.com')
|
|
||||||
HUB_WEB_ROOT = os.environ.get('ULTRALYTICS_HUB_WEB', 'https://hub.ultralytics.com')
|
|
||||||
|
|
||||||
|
|
||||||
def request_with_credentials(url: str) -> any:
|
def request_with_credentials(url: str) -> any:
|
||||||
|
@ -3,7 +3,9 @@
|
|||||||
import json
|
import json
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
from ultralytics.hub.utils import HUB_WEB_ROOT, PREFIX, events
|
from hub_sdk.config import HUB_WEB_ROOT
|
||||||
|
|
||||||
|
from ultralytics.hub.utils import PREFIX, events
|
||||||
from ultralytics.utils import LOGGER, SETTINGS
|
from ultralytics.utils import LOGGER, SETTINGS
|
||||||
|
|
||||||
|
|
||||||
@ -12,8 +14,9 @@ def on_pretrain_routine_end(trainer):
|
|||||||
session = getattr(trainer, 'hub_session', None)
|
session = getattr(trainer, 'hub_session', None)
|
||||||
if session:
|
if session:
|
||||||
# Start timer for upload rate limit
|
# Start timer for upload rate limit
|
||||||
LOGGER.info(f'{PREFIX}View model at {HUB_WEB_ROOT}/models/{session.model_id} 🚀')
|
session.timers = {
|
||||||
session.timers = {'metrics': time(), 'ckpt': time()} # start timer on session.rate_limit
|
'metrics': time(),
|
||||||
|
'ckpt': time(), } # start timer on session.rate_limit
|
||||||
|
|
||||||
|
|
||||||
def on_fit_epoch_end(trainer):
|
def on_fit_epoch_end(trainer):
|
||||||
@ -21,10 +24,13 @@ def on_fit_epoch_end(trainer):
|
|||||||
session = getattr(trainer, 'hub_session', None)
|
session = getattr(trainer, 'hub_session', None)
|
||||||
if session:
|
if session:
|
||||||
# Upload metrics after val end
|
# Upload metrics after val end
|
||||||
all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics}
|
all_plots = {
|
||||||
|
**trainer.label_loss_items(trainer.tloss, prefix='train'),
|
||||||
|
**trainer.metrics, }
|
||||||
if trainer.epoch == 0:
|
if trainer.epoch == 0:
|
||||||
from ultralytics.utils.torch_utils import model_info_for_loggers
|
from ultralytics.utils.torch_utils import model_info_for_loggers
|
||||||
all_plots = {**all_plots, **model_info_for_loggers(trainer)}
|
all_plots = {**all_plots, **model_info_for_loggers(trainer)}
|
||||||
|
|
||||||
session.metrics_queue[trainer.epoch] = json.dumps(all_plots)
|
session.metrics_queue[trainer.epoch] = json.dumps(all_plots)
|
||||||
if time() - session.timers['metrics'] > session.rate_limits['metrics']:
|
if time() - session.timers['metrics'] > session.rate_limits['metrics']:
|
||||||
session.upload_metrics()
|
session.upload_metrics()
|
||||||
@ -39,7 +45,7 @@ def on_model_save(trainer):
|
|||||||
# Upload checkpoints with rate limiting
|
# Upload checkpoints with rate limiting
|
||||||
is_best = trainer.best_fitness == trainer.fitness
|
is_best = trainer.best_fitness == trainer.fitness
|
||||||
if time() - session.timers['ckpt'] > session.rate_limits['ckpt']:
|
if time() - session.timers['ckpt'] > session.rate_limits['ckpt']:
|
||||||
LOGGER.info(f'{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_id}')
|
LOGGER.info(f'{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_file}')
|
||||||
session.upload_model(trainer.epoch, trainer.last, is_best)
|
session.upload_model(trainer.epoch, trainer.last, is_best)
|
||||||
session.timers['ckpt'] = time() # reset timer
|
session.timers['ckpt'] = time() # reset timer
|
||||||
|
|
||||||
@ -50,10 +56,15 @@ def on_train_end(trainer):
|
|||||||
if session:
|
if session:
|
||||||
# Upload final model and metrics with exponential backoff
|
# Upload final model and metrics with exponential backoff
|
||||||
LOGGER.info(f'{PREFIX}Syncing final model...')
|
LOGGER.info(f'{PREFIX}Syncing final model...')
|
||||||
session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True)
|
session.upload_model(
|
||||||
|
trainer.epoch,
|
||||||
|
trainer.best,
|
||||||
|
map=trainer.metrics.get('metrics/mAP50-95(B)', 0),
|
||||||
|
final=True,
|
||||||
|
)
|
||||||
session.alive = False # stop heartbeats
|
session.alive = False # stop heartbeats
|
||||||
LOGGER.info(f'{PREFIX}Done ✅\n'
|
LOGGER.info(f'{PREFIX}Done ✅\n'
|
||||||
f'{PREFIX}View model at {HUB_WEB_ROOT}/models/{session.model_id} 🚀')
|
f'{PREFIX}View model at {session.model_url} 🚀')
|
||||||
|
|
||||||
|
|
||||||
def on_train_start(trainer):
|
def on_train_start(trainer):
|
||||||
@ -76,7 +87,7 @@ def on_export_start(exporter):
|
|||||||
events(exporter.args)
|
events(exporter.args)
|
||||||
|
|
||||||
|
|
||||||
callbacks = {
|
callbacks = ({
|
||||||
'on_pretrain_routine_end': on_pretrain_routine_end,
|
'on_pretrain_routine_end': on_pretrain_routine_end,
|
||||||
'on_fit_epoch_end': on_fit_epoch_end,
|
'on_fit_epoch_end': on_fit_epoch_end,
|
||||||
'on_model_save': on_model_save,
|
'on_model_save': on_model_save,
|
||||||
@ -84,4 +95,4 @@ callbacks = {
|
|||||||
'on_train_start': on_train_start,
|
'on_train_start': on_train_start,
|
||||||
'on_val_start': on_val_start,
|
'on_val_start': on_val_start,
|
||||||
'on_predict_start': on_predict_start,
|
'on_predict_start': on_predict_start,
|
||||||
'on_export_start': on_export_start} if SETTINGS['hub'] is True else {} # verify enabled
|
'on_export_start': on_export_start, } if SETTINGS['hub'] is True else {}) # verify enabled
|
||||||
|
Loading…
x
Reference in New Issue
Block a user