Feature: Create HUB Models from CLI or Python Script (#7316)

Co-authored-by: Hassaan Farooq <103611273+hassaanfarooq01@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Kalen Michael 2024-01-10 02:36:14 +01:00 committed by GitHub
parent a92adf8231
commit b54055a2c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 356 additions and 154 deletions

View File

@ -78,6 +78,7 @@ dependencies = [
"thop>=0.1.1", # FLOPs computation
"pandas>=1.1.4",
"seaborn>=0.11.0", # plotting
"hub-sdk>=0.0.2", # Ultralytics HUB
]
# Optional dependencies ------------------------------------------------------------------------------------------------

View File

@ -5,10 +5,11 @@ import sys
from pathlib import Path
from typing import Union
from hub_sdk.config import HUB_WEB_ROOT
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
from ultralytics.hub.utils import HUB_WEB_ROOT
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, checks, emojis, yaml_load
from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, SETTINGS, callbacks, checks, emojis, yaml_load
class Model(nn.Module):
@ -76,8 +77,8 @@ class Model(nn.Module):
# Check if Ultralytics HUB model from https://hub.ultralytics.com
if self.is_hub_model(model):
from ultralytics.hub.session import HUBTrainingSession
self.session = HUBTrainingSession(model)
# Fetch model from HUB
self.session = self._get_hub_session(model)
model = self.session.model_file
# Check if Triton Server model
@ -93,10 +94,20 @@ class Model(nn.Module):
else:
self._load(model, task)
self.model_name = model
def __call__(self, source=None, stream=False, **kwargs):
"""Calls the predict() method with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs)
@staticmethod
def _get_hub_session(model: str):
"""Creates a session for Hub Training."""
from ultralytics.hub.session import HUBTrainingSession
session = HUBTrainingSession(model)
return session if session.client.authenticated else None
@staticmethod
def is_triton_model(model):
"""Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
@ -336,10 +347,11 @@ class Model(nn.Module):
**kwargs (Any): Any number of arguments representing the training configuration.
"""
self._check_is_pytorch_model()
if self.session: # Ultralytics HUB session
if hasattr(self.session, 'model') and self.session.model.id: # Ultralytics HUB session with loaded model
if any(kwargs):
LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
kwargs = self.session.train_args
kwargs = self.session.train_args # overwrite kwargs
checks.check_pip_update_available()
overrides = yaml_load(checks.check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides
@ -352,6 +364,20 @@ class Model(nn.Module):
if not args.get('resume'): # manually set model only if not resuming
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
if SETTINGS['hub'] is True and not self.session:
# Create a model in HUB
try:
self.session = self._get_hub_session(self.model_name)
if self.session:
self.session.create_model(args)
# Check model was created
if not getattr(self.session.model, 'id', None):
self.session = None
except PermissionError:
# Ignore permission error
pass
self.trainer.hub_session = self.session # attach optional HUB session
self.trainer.train()
# Update model and cfg after training

View File

@ -1,28 +1,49 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import requests
from hub_sdk import HUB_API_ROOT, HUB_WEB_ROOT, HUBClient
from ultralytics.data.utils import HUBDatasetStats
from ultralytics.hub.auth import Auth
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX
from ultralytics.hub.utils import PREFIX
from ultralytics.utils import LOGGER, SETTINGS
def login(api_key=''):
def login(api_key: str = None, save=True) -> bool:
"""
Log in to the Ultralytics HUB API using the provided API key.
The session is not stored; a new session is created when needed using the saved SETTINGS or the HUB_API_KEY environment variable if successfully authenticated.
Args:
api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
Example:
```python
from ultralytics import hub
hub.login('API_KEY')
```
api_key (str, optional): The API key to use for authentication. If not provided, it will be retrieved from SETTINGS or HUB_API_KEY environment variable.
save (bool, optional): Whether to save the API key to SETTINGS if authentication is successful.
Returns:
bool: True if authentication is successful, False otherwise.
"""
Auth(api_key, verbose=True)
api_key_url = f'{HUB_WEB_ROOT}/settings?tab=api+keys' # Set the redirect URL
saved_key = SETTINGS.get('api_key')
active_key = api_key or saved_key
credentials = {'api_key': active_key} if active_key and active_key != '' else None # Set credentials
client = HUBClient(credentials) # initialize HUBClient
if client.authenticated:
# Successfully authenticated with HUB
if save and client.api_key != saved_key:
SETTINGS.update({'api_key': client.api_key}) # update settings with valid API key
# Set message based on whether key was provided or retrieved from settings
log_message = ('New authentication successful ✅'
if client.api_key == api_key or not credentials else 'Authenticated ✅')
LOGGER.info(f'{PREFIX}{log_message}')
return True
else:
# Failed to authenticate with HUB
LOGGER.info(f'{PREFIX}Retrieve API key from {api_key_url}')
return False
def logout():
@ -43,7 +64,7 @@ def logout():
def reset_model(model_id=''):
"""Reset a trained model to an untrained state."""
r = requests.post(f'{HUB_API_ROOT}/model-reset', json={'apiKey': Auth().api_key, 'modelId': model_id})
r = requests.post(f'{HUB_API_ROOT}/model-reset', json={'modelId': model_id}, headers={'x-api-key': Auth().api_key})
if r.status_code == 200:
LOGGER.info(f'{PREFIX}Model reset successfully')
return
@ -73,7 +94,8 @@ def get_export(model_id='', format='torchscript'):
json={
'apiKey': Auth().api_key,
'modelId': model_id,
'format': format})
'format': format},
headers={'x-api-key': Auth().api_key})
assert r.status_code == 200, f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}'
return r.json()

View File

@ -1,8 +1,9 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import requests
from hub_sdk import HUB_API_ROOT, HUB_WEB_ROOT
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, request_with_credentials
from ultralytics.hub.utils import PREFIX, request_with_credentials
from ultralytics.utils import LOGGER, SETTINGS, emojis, is_colab
API_KEY_URL = f'{HUB_WEB_ROOT}/settings?tab=api+keys'

View File

@ -1,17 +1,18 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import signal
import sys
import threading
import time
from http import HTTPStatus
from pathlib import Path
from time import sleep
import requests
from hub_sdk import HUB_WEB_ROOT, HUBClient
from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, smart_request
from ultralytics.utils import LOGGER, __version__, checks, emojis, is_colab, threaded
from ultralytics.hub.utils import HELP_MSG, PREFIX, TQDM
from ultralytics.utils import LOGGER, SETTINGS, __version__, checks, emojis, is_colab
from ultralytics.utils.errors import HUBModelError
AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
AGENT_NAME = (f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local')
class HUBTrainingSession:
@ -34,7 +35,7 @@ class HUBTrainingSession:
alive (bool): Indicates if the heartbeat loop is active.
"""
def __init__(self, url):
def __init__(self, identifier):
"""
Initialize the HUBTrainingSession with the provided model identifier.
@ -46,98 +47,251 @@ class HUBTrainingSession:
ValueError: If the provided model identifier is invalid.
ConnectionError: If connecting with global API key is not supported.
"""
from ultralytics.hub.auth import Auth
self.rate_limits = {
'metrics': 3.0,
'ckpt': 900.0,
'heartbeat': 300.0, } # rate limits (seconds)
self.metrics_queue = {} # holds metrics for each epoch until upload
self.timers = {} # holds timers in ultralytics/utils/callbacks/hub.py
# Parse input
if url.startswith(f'{HUB_WEB_ROOT}/models/'):
url = url.split(f'{HUB_WEB_ROOT}/models/')[-1]
if [len(x) for x in url.split('_')] == [42, 20]:
key, model_id = url.split('_')
elif len(url) == 20:
key, model_id = '', url
else:
raise HUBModelError(f"model='{url}' not found. Check format is correct, i.e. "
f"model='{HUB_WEB_ROOT}/models/MODEL_ID' and try again.")
api_key, model_id, self.filename = self._parse_identifier(identifier)
# Authorize
auth = Auth(key)
self.agent_id = None # identifies which instance is communicating with server
self.model_id = model_id
self.model_url = f'{HUB_WEB_ROOT}/models/{model_id}'
self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
self.auth_header = auth.get_auth_header()
self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0} # rate limits (seconds)
self.timers = {} # rate limit timers (seconds)
self.metrics_queue = {} # metrics queue
self.model = self._get_model()
self.alive = True
self._start_heartbeat() # start heartbeats
self._register_signal_handlers()
# Get credentials
active_key = api_key or SETTINGS.get('api_key')
credentials = {'api_key': active_key} if active_key else None # set credentials
# Initialize client
self.client = HUBClient(credentials)
if model_id:
self.load_model(model_id) # load existing model
else:
self.model = self.client.model() # load empty model
def load_model(self, model_id):
# Initialize model
self.model = self.client.model(model_id)
self.model_url = f'{HUB_WEB_ROOT}/models/{self.model.id}'
self._set_train_args()
# Start heartbeats for HUB to monitor agent
self.model.start_heartbeat(self.rate_limits['heartbeat'])
LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')
def _register_signal_handlers(self):
"""Register signal handlers for SIGTERM and SIGINT signals to gracefully handle termination."""
signal.signal(signal.SIGTERM, self._handle_signal)
signal.signal(signal.SIGINT, self._handle_signal)
def create_model(self, model_args):
# Initialize model
payload = {
'config': {
'batchSize': model_args.get('batch', -1),
'epochs': model_args.get('epochs', 300),
'imageSize': model_args.get('imgsz', 640),
'patience': model_args.get('patience', 100),
'device': model_args.get('device', ''),
'cache': model_args.get('cache', 'ram'), },
'dataset': {
'name': model_args.get('data')},
'lineage': {
'architecture': {
'name': self.filename.replace('.pt', '').replace('.yaml', ''), },
'parent': {}, },
'meta': {
'name': self.filename}, }
def _handle_signal(self, signum, frame):
if self.filename.endswith('.pt'):
payload['lineage']['parent']['name'] = self.filename
self.model.create_model(payload)
# Model could not be created
# TODO: improve error handling
if not self.model.id:
return
self.model_url = f'{HUB_WEB_ROOT}/models/{self.model.id}'
# Start heartbeats for HUB to monitor agent
self.model.start_heartbeat(self.rate_limits['heartbeat'])
LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')
def _parse_identifier(self, identifier):
"""
Handle kill signals and prevent heartbeats from being sent on Colab after termination.
Parses the given identifier to determine the type of identifier and extract relevant components.
This method does not use frame, it is included as it is passed by signal.
The method supports different identifier formats:
- A HUB URL, which starts with HUB_WEB_ROOT followed by '/models/'
- An identifier containing an API key and a model ID separated by an underscore
- An identifier that is solely a model ID of a fixed length
- A local filename that ends with '.pt' or '.yaml'
Args:
identifier (str): The identifier string to be parsed.
Returns:
(tuple): A tuple containing the API key, model ID, and filename as applicable.
Raises:
HUBModelError: If the identifier format is not recognized.
"""
if self.alive is True:
LOGGER.info(f'{PREFIX}Kill signal received! ❌')
self._stop_heartbeat()
sys.exit(signum)
def _stop_heartbeat(self):
"""Terminate the heartbeat loop."""
self.alive = False
# Initialize variables
api_key, model_id, filename = None, None, None
# Check if identifier is a HUB URL
if identifier.startswith(f'{HUB_WEB_ROOT}/models/'):
# Extract the model_id after the HUB_WEB_ROOT URL
model_id = identifier.split(f'{HUB_WEB_ROOT}/models/')[-1]
else:
# Split the identifier based on underscores only if it's not a HUB URL
parts = identifier.split('_')
# Check if identifier is in the format of API key and model ID
if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
api_key, model_id = parts
# Check if identifier is a single model ID
elif len(parts) == 1 and len(parts[0]) == 20:
model_id = parts[0]
# Check if identifier is a local filename
elif identifier.endswith('.pt') or identifier.endswith('.yaml'):
filename = identifier
else:
raise HUBModelError(
f"model='{identifier}' could not be parsed. Check format is correct. "
f'Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file.')
return api_key, model_id, filename
def _set_train_args(self, **kwargs):
if self.model.is_trained():
# Model is already trained
raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀'))
if self.model.is_resumable():
# Model has saved weights
self.train_args = {'data': self.model.get_dataset_url(), 'resume': True}
self.model_file = self.model.get_weights_url('last')
else:
# Model has no saved weights
def get_train_args(config):
return {
'batch': config['batchSize'],
'epochs': config['epochs'],
'imgsz': config['imageSize'],
'patience': config['patience'],
'device': config['device'],
'cache': config['cache'],
'data': self.model.get_dataset_url(), }
self.train_args = get_train_args(self.model.data.get('config'))
# Set the model file as either a *.pt or *.yaml file
self.model_file = (self.model.get_weights_url('parent')
if self.model.is_pretrained() else self.model.get_architecture())
if not self.train_args.get('data'):
raise ValueError('Dataset may still be processing. Please wait a minute and try again.') # RF fix
self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u
self.model_id = self.model.id
def request_queue(
self,
request_func,
retry=3,
timeout=30,
thread=True,
verbose=True,
progress_total=None,
*args,
**kwargs,
):
def retry_request():
t0 = time.time() # Record the start time for the timeout
for i in range(retry + 1):
if (time.time() - t0) > timeout:
LOGGER.warning(f'{PREFIX}Timeout for request reached. {HELP_MSG}')
break # Timeout reached, exit loop
response = request_func(*args, **kwargs)
if progress_total:
self._show_upload_progress(progress_total, response)
if response is None:
LOGGER.warning(f'{PREFIX}Received no response from the request. {HELP_MSG}')
time.sleep(2 ** i) # Exponential backoff before retrying
continue # Skip further processing and retry
if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
return response # Success, no need to retry
if i == 0:
# Initial attempt, check status code and provide messages
message = self._get_failure_message(response, retry, timeout)
if verbose:
LOGGER.warning(f'{PREFIX}{message} {HELP_MSG} ({response.status_code})')
if not self._should_retry(response.status_code):
LOGGER.warning(f'{PREFIX}Request failed. {HELP_MSG} ({response.status_code}')
break # Not an error that should be retried, exit loop
time.sleep(2 ** i) # Exponential backoff for retries
return response
if thread:
# Start a new thread to run the retry_request function
threading.Thread(target=retry_request, daemon=True).start()
else:
# If running in the main thread, call retry_request directly
return retry_request()
def _should_retry(self, status_code):
# Status codes that trigger retries
retry_codes = {
HTTPStatus.REQUEST_TIMEOUT,
HTTPStatus.BAD_GATEWAY,
HTTPStatus.GATEWAY_TIMEOUT, }
return True if status_code in retry_codes else False
def _get_failure_message(self, response: requests.Response, retry: int, timeout: int):
"""
Generate a retry message based on the response status code.
Args:
response: The HTTP response object.
retry: The number of retry attempts allowed.
timeout: The maximum timeout duration.
Returns:
str: The retry message.
"""
if self._should_retry(response.status_code):
return f'Retrying {retry}x for {timeout}s.' if retry else ''
elif response.status_code == HTTPStatus.TOO_MANY_REQUESTS: # rate limit
headers = response.headers
return (f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). "
f"Please retry after {headers['Retry-After']}s.")
else:
try:
return response.json().get('message', 'No JSON message.')
except AttributeError:
return 'Unable to read JSON.'
def upload_metrics(self):
"""Upload model metrics to Ultralytics HUB."""
payload = {'metrics': self.metrics_queue.copy(), 'type': 'metrics'}
smart_request('post', self.api_url, json=payload, headers=self.auth_header, code=2)
return self.request_queue(self.model.upload_metrics, metrics=self.metrics_queue.copy(), thread=True)
def _get_model(self):
"""Fetch and return model data from Ultralytics HUB."""
api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'
try:
response = smart_request('get', api_url, headers=self.auth_header, thread=False, code=0)
data = response.json().get('data', None)
if data.get('status', None) == 'trained':
raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀'))
if not data.get('data', None):
raise ValueError('Dataset may still be processing. Please wait a minute and try again.') # RF fix
self.model_id = data['id']
if data['status'] == 'new': # new model to start training
self.train_args = {
'batch': data['batch_size'], # note HUB argument is slightly different
'epochs': data['epochs'],
'imgsz': data['imgsz'],
'patience': data['patience'],
'device': data['device'],
'cache': data['cache'],
'data': data['data']}
self.model_file = data.get('cfg') or data.get('weights') # cfg for pretrained=False
self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u
elif data['status'] == 'training': # existing model to resume training
self.train_args = {'data': data['data'], 'resume': True}
self.model_file = data['resume']
return data
except requests.exceptions.ConnectionError as e:
raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e
except Exception:
raise
def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
def upload_model(
self,
epoch: int,
weights: str,
is_best: bool = False,
map: float = 0.0,
final: bool = False,
) -> None:
"""
Upload a model checkpoint to Ultralytics HUB.
@ -149,43 +303,33 @@ class HUBTrainingSession:
final (bool): Indicates if the model is the final model after training.
"""
if Path(weights).is_file():
with open(weights, 'rb') as f:
file = f.read()
progress_total = (Path(weights).stat().st_size if final else None) # Only show progress if final
self.request_queue(
self.model.upload_model,
epoch=epoch,
weights=weights,
is_best=is_best,
map=map,
final=final,
retry=10,
timeout=3600,
thread=not final,
progress_total=progress_total,
)
else:
LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.')
file = None
url = f'{self.api_url}/upload'
# url = 'http://httpbin.org/post' # for debug
data = {'epoch': epoch}
if final:
data.update({'type': 'final', 'map': map})
filesize = Path(weights).stat().st_size
smart_request('post',
url,
data=data,
files={'best.pt': file},
headers=self.auth_header,
retry=10,
timeout=3600,
thread=False,
progress=filesize,
code=4)
else:
data.update({'type': 'epoch', 'isBest': bool(is_best)})
smart_request('post', url, data=data, files={'last.pt': file}, headers=self.auth_header, code=3)
@threaded
def _start_heartbeat(self):
"""Begin a threaded heartbeat loop to report the agent's status to Ultralytics HUB."""
while self.alive:
r = smart_request('post',
f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
json={
'agent': AGENT_NAME,
'agentId': self.agent_id},
headers=self.auth_header,
retry=0,
code=5,
thread=False) # already in a thread
self.agent_id = r.json().get('data', {}).get('agentId', None)
sleep(self.rate_limits['heartbeat'])
def _show_upload_progress(self, content_length: int, response: requests.Response) -> None:
"""
Display a progress bar to track the upload progress of a file download.
Args:
content_length (int): The total size of the content to be downloaded in bytes.
response (requests.Response): The response object from the file download request.
Returns:
(None)
"""
with TQDM(total=content_length, unit='B', unit_scale=True, unit_divisor=1024) as pbar:
for data in response.iter_content(chunk_size=1024):
pbar.update(len(data))

View File

@ -1,6 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import os
import platform
import random
import sys
@ -16,8 +15,6 @@ from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES
PREFIX = colorstr('Ultralytics HUB: ')
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
HUB_API_ROOT = os.environ.get('ULTRALYTICS_HUB_API', 'https://api.ultralytics.com')
HUB_WEB_ROOT = os.environ.get('ULTRALYTICS_HUB_WEB', 'https://hub.ultralytics.com')
def request_with_credentials(url: str) -> any:

View File

@ -3,7 +3,9 @@
import json
from time import time
from ultralytics.hub.utils import HUB_WEB_ROOT, PREFIX, events
from hub_sdk.config import HUB_WEB_ROOT
from ultralytics.hub.utils import PREFIX, events
from ultralytics.utils import LOGGER, SETTINGS
@ -12,8 +14,9 @@ def on_pretrain_routine_end(trainer):
session = getattr(trainer, 'hub_session', None)
if session:
# Start timer for upload rate limit
LOGGER.info(f'{PREFIX}View model at {HUB_WEB_ROOT}/models/{session.model_id} 🚀')
session.timers = {'metrics': time(), 'ckpt': time()} # start timer on session.rate_limit
session.timers = {
'metrics': time(),
'ckpt': time(), } # start timer on session.rate_limit
def on_fit_epoch_end(trainer):
@ -21,10 +24,13 @@ def on_fit_epoch_end(trainer):
session = getattr(trainer, 'hub_session', None)
if session:
# Upload metrics after val end
all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics}
all_plots = {
**trainer.label_loss_items(trainer.tloss, prefix='train'),
**trainer.metrics, }
if trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
all_plots = {**all_plots, **model_info_for_loggers(trainer)}
session.metrics_queue[trainer.epoch] = json.dumps(all_plots)
if time() - session.timers['metrics'] > session.rate_limits['metrics']:
session.upload_metrics()
@ -39,7 +45,7 @@ def on_model_save(trainer):
# Upload checkpoints with rate limiting
is_best = trainer.best_fitness == trainer.fitness
if time() - session.timers['ckpt'] > session.rate_limits['ckpt']:
LOGGER.info(f'{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_id}')
LOGGER.info(f'{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_file}')
session.upload_model(trainer.epoch, trainer.last, is_best)
session.timers['ckpt'] = time() # reset timer
@ -50,10 +56,15 @@ def on_train_end(trainer):
if session:
# Upload final model and metrics with exponential standoff
LOGGER.info(f'{PREFIX}Syncing final model...')
session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True)
session.upload_model(
trainer.epoch,
trainer.best,
map=trainer.metrics.get('metrics/mAP50-95(B)', 0),
final=True,
)
session.alive = False # stop heartbeats
LOGGER.info(f'{PREFIX}Done ✅\n'
f'{PREFIX}View model at {HUB_WEB_ROOT}/models/{session.model_id} 🚀')
f'{PREFIX}View model at {session.model_url} 🚀')
def on_train_start(trainer):
@ -76,7 +87,7 @@ def on_export_start(exporter):
events(exporter.args)
callbacks = {
callbacks = ({
'on_pretrain_routine_end': on_pretrain_routine_end,
'on_fit_epoch_end': on_fit_epoch_end,
'on_model_save': on_model_save,
@ -84,4 +95,4 @@ callbacks = {
'on_train_start': on_train_start,
'on_val_start': on_val_start,
'on_predict_start': on_predict_start,
'on_export_start': on_export_start} if SETTINGS['hub'] is True else {} # verify enabled
'on_export_start': on_export_start, } if SETTINGS['hub'] is True else {}) # verify enabled