diff --git a/pyproject.toml b/pyproject.toml
index 9dced7e8..7b734356 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -78,6 +78,7 @@ dependencies = [
     "thop>=0.1.1",  # FLOPs computation
     "pandas>=1.1.4",
     "seaborn>=0.11.0",  # plotting
+    "hub-sdk>=0.0.2",  # Ultralytics HUB
 ]

 # Optional dependencies ------------------------------------------------------------------------------------------------
diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py
index ced3f18b..66052f30 100644
--- a/ultralytics/engine/model.py
+++ b/ultralytics/engine/model.py
@@ -5,10 +5,11 @@ import sys
 from pathlib import Path
 from typing import Union

+from hub_sdk.config import HUB_WEB_ROOT
+
 from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
-from ultralytics.hub.utils import HUB_WEB_ROOT
 from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
-from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, checks, emojis, yaml_load
+from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, SETTINGS, callbacks, checks, emojis, yaml_load


 class Model(nn.Module):
@@ -76,8 +77,8 @@ class Model(nn.Module):

         # Check if Ultralytics HUB model from https://hub.ultralytics.com
         if self.is_hub_model(model):
-            from ultralytics.hub.session import HUBTrainingSession
-            self.session = HUBTrainingSession(model)
+            # Fetch model from HUB
+            self.session = self._get_hub_session(model)
             model = self.session.model_file

         # Check if Triton Server model
@@ -93,10 +94,20 @@ class Model(nn.Module):
         else:
             self._load(model, task)

+        self.model_name = model
+
     def __call__(self, source=None, stream=False, **kwargs):
         """Calls the predict() method with given arguments to perform object detection."""
         return self.predict(source, stream, **kwargs)

+    @staticmethod
+    def _get_hub_session(model: str):
+        """Creates a session for Hub Training."""
+        from ultralytics.hub.session import HUBTrainingSession
+
+        session = HUBTrainingSession(model)
+        return session if session.client.authenticated else None
+
     @staticmethod
     def is_triton_model(model):
         """Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
@@ -336,10 +347,11 @@ class Model(nn.Module):
             **kwargs (Any): Any number of arguments representing the training configuration.
""" self._check_is_pytorch_model() - if self.session: # Ultralytics HUB session + if hasattr(self.session, 'model') and self.session.model.id: # Ultralytics HUB session with loaded model if any(kwargs): LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.') - kwargs = self.session.train_args + kwargs = self.session.train_args # overwrite kwargs + checks.check_pip_update_available() overrides = yaml_load(checks.check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides @@ -352,6 +364,20 @@ class Model(nn.Module): if not args.get('resume'): # manually set model only if not resuming self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml) self.model = self.trainer.model + + if SETTINGS['hub'] is True and not self.session: + # Create a model in HUB + try: + self.session = self._get_hub_session(self.model_name) + if self.session: + self.session.create_model(args) + # Check model was created + if not getattr(self.session.model, 'id', None): + self.session = None + except PermissionError: + # Ignore permission error + pass + self.trainer.hub_session = self.session # attach optional HUB session self.trainer.train() # Update model and cfg after training diff --git a/ultralytics/hub/__init__.py b/ultralytics/hub/__init__.py index 8e101d6b..13bc9246 100644 --- a/ultralytics/hub/__init__.py +++ b/ultralytics/hub/__init__.py @@ -1,28 +1,49 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license import requests +from hub_sdk import HUB_API_ROOT, HUB_WEB_ROOT, HUBClient from ultralytics.data.utils import HUBDatasetStats from ultralytics.hub.auth import Auth -from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX +from ultralytics.hub.utils import PREFIX from ultralytics.utils import LOGGER, SETTINGS -def login(api_key=''): +def login(api_key: str = None, save=True) -> bool: """ Log in to the Ultralytics HUB API using the provided API key. + The session is not stored; a new session is created when needed using the saved SETTINGS or the HUB_API_KEY environment variable if successfully authenticated. + Args: - api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id - - Example: - ```python - from ultralytics import hub - - hub.login('API_KEY') - ``` + api_key (str, optional): The API key to use for authentication. If not provided, it will be retrieved from SETTINGS or HUB_API_KEY environment variable. + save (bool, optional): Whether to save the API key to SETTINGS if authentication is successful. + Returns: + bool: True if authentication is successful, False otherwise. 
""" - Auth(api_key, verbose=True) + api_key_url = f'{HUB_WEB_ROOT}/settings?tab=api+keys' # Set the redirect URL + saved_key = SETTINGS.get('api_key') + active_key = api_key or saved_key + credentials = {'api_key': active_key} if active_key and active_key != '' else None # Set credentials + + client = HUBClient(credentials) # initialize HUBClient + + if client.authenticated: + # Successfully authenticated with HUB + + if save and client.api_key != saved_key: + SETTINGS.update({'api_key': client.api_key}) # update settings with valid API key + + # Set message based on whether key was provided or retrieved from settings + log_message = ('New authentication successful ✅' + if client.api_key == api_key or not credentials else 'Authenticated ✅') + LOGGER.info(f'{PREFIX}{log_message}') + + return True + else: + # Failed to authenticate with HUB + LOGGER.info(f'{PREFIX}Retrieve API key from {api_key_url}') + return False def logout(): @@ -43,7 +64,7 @@ def logout(): def reset_model(model_id=''): """Reset a trained model to an untrained state.""" - r = requests.post(f'{HUB_API_ROOT}/model-reset', json={'apiKey': Auth().api_key, 'modelId': model_id}) + r = requests.post(f'{HUB_API_ROOT}/model-reset', json={'modelId': model_id}, headers={'x-api-key': Auth().api_key}) if r.status_code == 200: LOGGER.info(f'{PREFIX}Model reset successfully') return @@ -73,7 +94,8 @@ def get_export(model_id='', format='torchscript'): json={ 'apiKey': Auth().api_key, 'modelId': model_id, - 'format': format}) + 'format': format}, + headers={'x-api-key': Auth().api_key}) assert r.status_code == 200, f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}' return r.json() diff --git a/ultralytics/hub/auth.py b/ultralytics/hub/auth.py index deea9a32..202ca8c4 100644 --- a/ultralytics/hub/auth.py +++ b/ultralytics/hub/auth.py @@ -1,8 +1,9 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license import requests +from hub_sdk import HUB_API_ROOT, HUB_WEB_ROOT -from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, request_with_credentials +from ultralytics.hub.utils import PREFIX, request_with_credentials from ultralytics.utils import LOGGER, SETTINGS, emojis, is_colab API_KEY_URL = f'{HUB_WEB_ROOT}/settings?tab=api+keys' diff --git a/ultralytics/hub/session.py b/ultralytics/hub/session.py index d2fd89a9..dd3d01c0 100644 --- a/ultralytics/hub/session.py +++ b/ultralytics/hub/session.py @@ -1,17 +1,18 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -import signal -import sys +import threading +import time +from http import HTTPStatus from pathlib import Path -from time import sleep import requests +from hub_sdk import HUB_WEB_ROOT, HUBClient -from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, smart_request -from ultralytics.utils import LOGGER, __version__, checks, emojis, is_colab, threaded +from ultralytics.hub.utils import HELP_MSG, PREFIX, TQDM +from ultralytics.utils import LOGGER, SETTINGS, __version__, checks, emojis, is_colab from ultralytics.utils.errors import HUBModelError -AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local' +AGENT_NAME = (f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local') class HUBTrainingSession: @@ -34,7 +35,7 @@ class HUBTrainingSession: alive (bool): Indicates if the heartbeat loop is active. """ - def __init__(self, url): + def __init__(self, identifier): """ Initialize the HUBTrainingSession with the provided model identifier. 
@@ -46,98 +47,251 @@ class HUBTrainingSession:
             ValueError: If the provided model identifier is invalid.
             ConnectionError: If connecting with global API key is not supported.
         """
-
-        from ultralytics.hub.auth import Auth
+        self.rate_limits = {
+            'metrics': 3.0,
+            'ckpt': 900.0,
+            'heartbeat': 300.0, }  # rate limits (seconds)
+        self.metrics_queue = {}  # holds metrics for each epoch until upload
+        self.timers = {}  # holds timers in ultralytics/utils/callbacks/hub.py

         # Parse input
-        if url.startswith(f'{HUB_WEB_ROOT}/models/'):
-            url = url.split(f'{HUB_WEB_ROOT}/models/')[-1]
-        if [len(x) for x in url.split('_')] == [42, 20]:
-            key, model_id = url.split('_')
-        elif len(url) == 20:
-            key, model_id = '', url
-        else:
-            raise HUBModelError(f"model='{url}' not found. Check format is correct, i.e. "
-                                f"model='{HUB_WEB_ROOT}/models/MODEL_ID' and try again.")
+        api_key, model_id, self.filename = self._parse_identifier(identifier)

-        # Authorize
-        auth = Auth(key)
-        self.agent_id = None  # identifies which instance is communicating with server
-        self.model_id = model_id
-        self.model_url = f'{HUB_WEB_ROOT}/models/{model_id}'
-        self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
-        self.auth_header = auth.get_auth_header()
-        self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0}  # rate limits (seconds)
-        self.timers = {}  # rate limit timers (seconds)
-        self.metrics_queue = {}  # metrics queue
-        self.model = self._get_model()
-        self.alive = True
-        self._start_heartbeat()  # start heartbeats
-        self._register_signal_handlers()
+        # Get credentials
+        active_key = api_key or SETTINGS.get('api_key')
+        credentials = {'api_key': active_key} if active_key else None  # set credentials
+
+        # Initialize client
+        self.client = HUBClient(credentials)
+
+        if model_id:
+            self.load_model(model_id)  # load existing model
+        else:
+            self.model = self.client.model()  # load empty model
+
+    def load_model(self, model_id):
+        # Initialize model
+        self.model = self.client.model(model_id)
+        self.model_url = f'{HUB_WEB_ROOT}/models/{self.model.id}'
+
+        self._set_train_args()
+
+        # Start heartbeats for HUB to monitor agent
+        self.model.start_heartbeat(self.rate_limits['heartbeat'])
         LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')

-    def _register_signal_handlers(self):
-        """Register signal handlers for SIGTERM and SIGINT signals to gracefully handle termination."""
-        signal.signal(signal.SIGTERM, self._handle_signal)
-        signal.signal(signal.SIGINT, self._handle_signal)
+    def create_model(self, model_args):
+        # Initialize model
+        payload = {
+            'config': {
+                'batchSize': model_args.get('batch', -1),
+                'epochs': model_args.get('epochs', 300),
+                'imageSize': model_args.get('imgsz', 640),
+                'patience': model_args.get('patience', 100),
+                'device': model_args.get('device', ''),
+                'cache': model_args.get('cache', 'ram'), },
+            'dataset': {
+                'name': model_args.get('data')},
+            'lineage': {
+                'architecture': {
+                    'name': self.filename.replace('.pt', '').replace('.yaml', ''), },
+                'parent': {}, },
+            'meta': {
+                'name': self.filename}, }

-    def _handle_signal(self, signum, frame):
+        if self.filename.endswith('.pt'):
+            payload['lineage']['parent']['name'] = self.filename
+
+        self.model.create_model(payload)
+
+        # Model could not be created
+        # TODO: improve error handling
+        if not self.model.id:
+            return
+
+        self.model_url = f'{HUB_WEB_ROOT}/models/{self.model.id}'
+
+        # Start heartbeats for HUB to monitor agent
+        self.model.start_heartbeat(self.rate_limits['heartbeat'])
+
+        LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')
+
+    def _parse_identifier(self, identifier):
         """
-        Handle kill signals and prevent heartbeats from being sent on Colab after termination.
+        Parses the given identifier to determine the type of identifier and extract relevant components.

-        This method does not use frame, it is included as it is passed by signal.
+        The method supports different identifier formats:
+            - A HUB URL, which starts with HUB_WEB_ROOT followed by '/models/'
+            - An identifier containing an API key and a model ID separated by an underscore
+            - An identifier that is solely a model ID of a fixed length
+            - A local filename that ends with '.pt' or '.yaml'
+
+        Args:
+            identifier (str): The identifier string to be parsed.
+
+        Returns:
+            (tuple): A tuple containing the API key, model ID, and filename as applicable.
+
+        Raises:
+            HUBModelError: If the identifier format is not recognized.
         """
-        if self.alive is True:
-            LOGGER.info(f'{PREFIX}Kill signal received! ❌')
-            self._stop_heartbeat()
-            sys.exit(signum)
-    def _stop_heartbeat(self):
-        """Terminate the heartbeat loop."""
-        self.alive = False
+
+        # Initialize variables
+        api_key, model_id, filename = None, None, None
+
+        # Check if identifier is a HUB URL
+        if identifier.startswith(f'{HUB_WEB_ROOT}/models/'):
+            # Extract the model_id after the HUB_WEB_ROOT URL
+            model_id = identifier.split(f'{HUB_WEB_ROOT}/models/')[-1]
+        else:
+            # Split the identifier based on underscores only if it's not a HUB URL
+            parts = identifier.split('_')
+
+            # Check if identifier is in the format of API key and model ID
+            if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
+                api_key, model_id = parts
+            # Check if identifier is a single model ID
+            elif len(parts) == 1 and len(parts[0]) == 20:
+                model_id = parts[0]
+            # Check if identifier is a local filename
+            elif identifier.endswith('.pt') or identifier.endswith('.yaml'):
+                filename = identifier
+            else:
+                raise HUBModelError(
+                    f"model='{identifier}' could not be parsed. Check format is correct. "
+                    f'Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file.')
+
+        return api_key, model_id, filename
+
+    def _set_train_args(self, **kwargs):
+        if self.model.is_trained():
+            # Model is already trained
+            raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀'))
+
+        if self.model.is_resumable():
+            # Model has saved weights
+            self.train_args = {'data': self.model.get_dataset_url(), 'resume': True}
+            self.model_file = self.model.get_weights_url('last')
+        else:
+            # Model has no saved weights
+            def get_train_args(config):
+                return {
+                    'batch': config['batchSize'],
+                    'epochs': config['epochs'],
+                    'imgsz': config['imageSize'],
+                    'patience': config['patience'],
+                    'device': config['device'],
+                    'cache': config['cache'],
+                    'data': self.model.get_dataset_url(), }
+
+            self.train_args = get_train_args(self.model.data.get('config'))
+            # Set the model file as either a *.pt or *.yaml file
+            self.model_file = (self.model.get_weights_url('parent')
+                               if self.model.is_pretrained() else self.model.get_architecture())
+
+        if not self.train_args.get('data'):
+            raise ValueError('Dataset may still be processing. Please wait a minute and try again.')  # RF fix
+
+        self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False)  # YOLOv5->YOLOv5u
+        self.model_id = self.model.id
+
+    def request_queue(
+        self,
+        request_func,
+        retry=3,
+        timeout=30,
+        thread=True,
+        verbose=True,
+        progress_total=None,
+        *args,
+        **kwargs,
+    ):
+
+        def retry_request():
+            t0 = time.time()  # Record the start time for the timeout
+            for i in range(retry + 1):
+                if (time.time() - t0) > timeout:
+                    LOGGER.warning(f'{PREFIX}Timeout for request reached. {HELP_MSG}')
+                    break  # Timeout reached, exit loop
+
+                response = request_func(*args, **kwargs)
+                if progress_total:
+                    self._show_upload_progress(progress_total, response)
+
+                if response is None:
+                    LOGGER.warning(f'{PREFIX}Received no response from the request. {HELP_MSG}')
+                    time.sleep(2 ** i)  # Exponential backoff before retrying
+                    continue  # Skip further processing and retry
+
+                if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
+                    return response  # Success, no need to retry
+
+                if i == 0:
+                    # Initial attempt, check status code and provide messages
+                    message = self._get_failure_message(response, retry, timeout)
+
+                    if verbose:
+                        LOGGER.warning(f'{PREFIX}{message} {HELP_MSG} ({response.status_code})')
+
+                if not self._should_retry(response.status_code):
+                    LOGGER.warning(f'{PREFIX}Request failed. {HELP_MSG} ({response.status_code})')
+                    break  # Not an error that should be retried, exit loop
+
+                time.sleep(2 ** i)  # Exponential backoff for retries
+
+            return response
+
+        if thread:
+            # Start a new thread to run the retry_request function
+            threading.Thread(target=retry_request, daemon=True).start()
+        else:
+            # If running in the main thread, call retry_request directly
+            return retry_request()
+
+    def _should_retry(self, status_code):
+        # Status codes that trigger retries
+        retry_codes = {
+            HTTPStatus.REQUEST_TIMEOUT,
+            HTTPStatus.BAD_GATEWAY,
+            HTTPStatus.GATEWAY_TIMEOUT, }
+        return True if status_code in retry_codes else False
+
+    def _get_failure_message(self, response: requests.Response, retry: int, timeout: int):
+        """
+        Generate a retry message based on the response status code.
+
+        Args:
+            response: The HTTP response object.
+            retry: The number of retry attempts allowed.
+            timeout: The maximum timeout duration.
+
+        Returns:
+            str: The retry message.
+        """
+        if self._should_retry(response.status_code):
+            return f'Retrying {retry}x for {timeout}s.' if retry else ''
+        elif response.status_code == HTTPStatus.TOO_MANY_REQUESTS:  # rate limit
+            headers = response.headers
+            return (f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). "
+                    f"Please retry after {headers['Retry-After']}s.")
+        else:
+            try:
+                return response.json().get('message', 'No JSON message.')
+            except AttributeError:
+                return 'Unable to read JSON.'

     def upload_metrics(self):
         """Upload model metrics to Ultralytics HUB."""
-        payload = {'metrics': self.metrics_queue.copy(), 'type': 'metrics'}
-        smart_request('post', self.api_url, json=payload, headers=self.auth_header, code=2)
+        return self.request_queue(self.model.upload_metrics, metrics=self.metrics_queue.copy(), thread=True)

-    def _get_model(self):
-        """Fetch and return model data from Ultralytics HUB."""
-        api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'
-
-        try:
-            response = smart_request('get', api_url, headers=self.auth_header, thread=False, code=0)
-            data = response.json().get('data', None)
-
-            if data.get('status', None) == 'trained':
-                raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀'))
-
-            if not data.get('data', None):
-                raise ValueError('Dataset may still be processing. Please wait a minute and try again.')  # RF fix
-            self.model_id = data['id']
-
-            if data['status'] == 'new':  # new model to start training
-                self.train_args = {
-                    'batch': data['batch_size'],  # note HUB argument is slightly different
-                    'epochs': data['epochs'],
-                    'imgsz': data['imgsz'],
-                    'patience': data['patience'],
-                    'device': data['device'],
-                    'cache': data['cache'],
-                    'data': data['data']}
-                self.model_file = data.get('cfg') or data.get('weights')  # cfg for pretrained=False
-                self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False)  # YOLOv5->YOLOv5u
-            elif data['status'] == 'training':  # existing model to resume training
-                self.train_args = {'data': data['data'], 'resume': True}
-                self.model_file = data['resume']
-
-            return data
-        except requests.exceptions.ConnectionError as e:
-            raise ConnectionRefusedError('ERROR: The HUB server is not online. Please try again later.') from e
-        except Exception:
-            raise
-    def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
+    def upload_model(
+        self,
+        epoch: int,
+        weights: str,
+        is_best: bool = False,
+        map: float = 0.0,
+        final: bool = False,
+    ) -> None:
         """
         Upload a model checkpoint to Ultralytics HUB.

@@ -149,43 +303,33 @@ class HUBTrainingSession:
             final (bool): Indicates if the model is the final model after training.
         """
         if Path(weights).is_file():
-            with open(weights, 'rb') as f:
-                file = f.read()
+            progress_total = (Path(weights).stat().st_size if final else None)  # Only show progress if final
+            self.request_queue(
+                self.model.upload_model,
+                epoch=epoch,
+                weights=weights,
+                is_best=is_best,
+                map=map,
+                final=final,
+                retry=10,
+                timeout=3600,
+                thread=not final,
+                progress_total=progress_total,
+            )
         else:
             LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.')
-            file = None
-        url = f'{self.api_url}/upload'
-        # url = 'http://httpbin.org/post'  # for debug
-        data = {'epoch': epoch}
-        if final:
-            data.update({'type': 'final', 'map': map})
-            filesize = Path(weights).stat().st_size
-            smart_request('post',
-                          url,
-                          data=data,
-                          files={'best.pt': file},
-                          headers=self.auth_header,
-                          retry=10,
-                          timeout=3600,
-                          thread=False,
-                          progress=filesize,
-                          code=4)
-        else:
-            data.update({'type': 'epoch', 'isBest': bool(is_best)})
-            smart_request('post', url, data=data, files={'last.pt': file}, headers=self.auth_header, code=3)

-    @threaded
-    def _start_heartbeat(self):
-        """Begin a threaded heartbeat loop to report the agent's status to Ultralytics HUB."""
-        while self.alive:
-            r = smart_request('post',
-                              f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
-                              json={
-                                  'agent': AGENT_NAME,
-                                  'agentId': self.agent_id},
-                              headers=self.auth_header,
-                              retry=0,
-                              code=5,
-                              thread=False)  # already in a thread
-            self.agent_id = r.json().get('data', {}).get('agentId', None)
-            sleep(self.rate_limits['heartbeat'])
+    def _show_upload_progress(self, content_length: int, response: requests.Response) -> None:
+        """
+        Display a progress bar to track the upload progress of a file.
+
+        Args:
+            content_length (int): The total size of the content to be uploaded in bytes.
+            response (requests.Response): The response object from the file upload request.
+
+        Returns:
+            (None)
+        """
+        with TQDM(total=content_length, unit='B', unit_scale=True, unit_divisor=1024) as pbar:
+            for data in response.iter_content(chunk_size=1024):
+                pbar.update(len(data))
diff --git a/ultralytics/hub/utils.py b/ultralytics/hub/utils.py
index f2621d7a..1277c63a 100644
--- a/ultralytics/hub/utils.py
+++ b/ultralytics/hub/utils.py
@@ -1,6 +1,5 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-import os
 import platform
 import random
 import sys
@@ -16,8 +15,6 @@ from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES

 PREFIX = colorstr('Ultralytics HUB: ')
 HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
-HUB_API_ROOT = os.environ.get('ULTRALYTICS_HUB_API', 'https://api.ultralytics.com')
-HUB_WEB_ROOT = os.environ.get('ULTRALYTICS_HUB_WEB', 'https://hub.ultralytics.com')


 def request_with_credentials(url: str) -> any:
diff --git a/ultralytics/utils/callbacks/hub.py b/ultralytics/utils/callbacks/hub.py
index 7171fb90..f3a3f353 100644
--- a/ultralytics/utils/callbacks/hub.py
+++ b/ultralytics/utils/callbacks/hub.py
@@ -3,7 +3,9 @@
 import json
 from time import time

-from ultralytics.hub.utils import HUB_WEB_ROOT, PREFIX, events
+from hub_sdk.config import HUB_WEB_ROOT
+
+from ultralytics.hub.utils import PREFIX, events
 from ultralytics.utils import LOGGER, SETTINGS


@@ -12,8 +14,9 @@ def on_pretrain_routine_end(trainer):
     session = getattr(trainer, 'hub_session', None)
     if session:
         # Start timer for upload rate limit
-        LOGGER.info(f'{PREFIX}View model at {HUB_WEB_ROOT}/models/{session.model_id} 🚀')
-        session.timers = {'metrics': time(), 'ckpt': time()}  # start timer on session.rate_limit
+        session.timers = {
+            'metrics': time(),
+            'ckpt': time(), }  # start timer on session.rate_limit


 def on_fit_epoch_end(trainer):
@@ -21,10 +24,13 @@ def on_fit_epoch_end(trainer):
     session = getattr(trainer, 'hub_session', None)
     if session:
         # Upload metrics after val end
-        all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics}
+        all_plots = {
+            **trainer.label_loss_items(trainer.tloss, prefix='train'),
+            **trainer.metrics, }
         if trainer.epoch == 0:
             from ultralytics.utils.torch_utils import model_info_for_loggers
             all_plots = {**all_plots, **model_info_for_loggers(trainer)}
+
         session.metrics_queue[trainer.epoch] = json.dumps(all_plots)
         if time() - session.timers['metrics'] > session.rate_limits['metrics']:
             session.upload_metrics()
@@ -39,7 +45,7 @@ def on_model_save(trainer):
         # Upload checkpoints with rate limiting
         is_best = trainer.best_fitness == trainer.fitness
         if time() - session.timers['ckpt'] > session.rate_limits['ckpt']:
-            LOGGER.info(f'{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_id}')
+            LOGGER.info(f'{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model_file}')
             session.upload_model(trainer.epoch, trainer.last, is_best)
             session.timers['ckpt'] = time()  # reset timer

@@ -50,10 +56,15 @@ def on_train_end(trainer):
     if session:
         # Upload final model and metrics with exponential standoff
         LOGGER.info(f'{PREFIX}Syncing final model...')
-        session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True)
+        session.upload_model(
+            trainer.epoch,
+            trainer.best,
+            map=trainer.metrics.get('metrics/mAP50-95(B)', 0),
+            final=True,
+        )
         session.alive = False  # stop heartbeats
         LOGGER.info(f'{PREFIX}Done ✅\n'
-                    f'{PREFIX}View model at {HUB_WEB_ROOT}/models/{session.model_id} 🚀')
+                    f'{PREFIX}View model at {session.model_url} 🚀')


 def on_train_start(trainer):
@@ -76,7 +87,7 @@ def on_export_start(exporter):
     events(exporter.args)


-callbacks = {
+callbacks = ({
     'on_pretrain_routine_end': on_pretrain_routine_end,
     'on_fit_epoch_end': on_fit_epoch_end,
     'on_model_save': on_model_save,
@@ -84,4 +95,4 @@ callbacks = {
     'on_train_start': on_train_start,
     'on_val_start': on_val_start,
     'on_predict_start': on_predict_start,
-    'on_export_start': on_export_start} if SETTINGS['hub'] is True else {}  # verify enabled
+    'on_export_start': on_export_start, } if SETTINGS['hub'] is True else {})  # verify enabled
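
Reviewer note: the sketch below is a minimal usage example of the flow this patch enables, kept outside the patch so it cannot interfere with applying it. `YOUR_API_KEY` and `MODEL_ID` are placeholders, not real credentials, and the behavior described is inferred from the changed code above rather than from separate documentation.

```python
from ultralytics import YOLO, hub

# Placeholder key and model ID; login() now returns True/False instead of relying on Auth side effects
if hub.login('YOUR_API_KEY'):
    # A HUB model URL makes Model.__init__ create a HUBTrainingSession via _get_hub_session()
    model = YOLO('https://hub.ultralytics.com/models/MODEL_ID')
    model.train()  # HUB-provided train_args override any local training arguments
else:
    print('Authentication failed; retrieve a key from https://hub.ultralytics.com/settings?tab=api+keys')
```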
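For reference, the retry behavior that `HUBTrainingSession.request_queue` introduces amounts to bounded retries with exponential backoff inside an overall time budget. The standalone sketch below mirrors that logic outside the class for illustration only; the `send` callable is a hypothetical stand-in for `request_func`, and this is not the patch's exact implementation.

```python
import time
from http import HTTPStatus
from typing import Callable, Optional

import requests

# Same retryable statuses as _should_retry in the patch
RETRY_CODES = {HTTPStatus.REQUEST_TIMEOUT, HTTPStatus.BAD_GATEWAY, HTTPStatus.GATEWAY_TIMEOUT}


def retry_with_backoff(send: Callable[[], Optional[requests.Response]],
                       retry: int = 3,
                       timeout: int = 30) -> Optional[requests.Response]:
    """Call `send` up to `retry + 1` times, sleeping 2**i seconds between attempts, within `timeout` seconds."""
    t0 = time.time()
    response = None
    for i in range(retry + 1):
        if (time.time() - t0) > timeout:
            break  # overall time budget exhausted
        response = send()
        if response is not None and HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
            return response  # 2xx success, stop retrying
        if response is not None and response.status_code not in RETRY_CODES:
            break  # non-retryable status, give up immediately
        time.sleep(2 ** i)  # exponential backoff: 1s, 2s, 4s, ...
    return response
```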