# Ultralytics YOLO 🚀, AGPL-3.0 license

import threading
import time
from http import HTTPStatus
from pathlib import Path

import requests
from hub_sdk import HUB_WEB_ROOT, HUBClient

from ultralytics.hub.utils import HELP_MSG, PREFIX, TQDM
from ultralytics.utils import LOGGER, SETTINGS, __version__, checks, emojis, is_colab
from ultralytics.utils.errors import HUBModelError

AGENT_NAME = (f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local')


class HUBTrainingSession:
    """
    HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.

    Args:
        identifier (str): Model identifier used to initialize the HUB training session. Supported formats are a HUB
            model URL, an 'apiKey_modelId' string, a bare model ID, or a local '.pt'/'.yaml' filename.

    Attributes:
        rate_limits (dict): Rate limits for different API calls (in seconds).
        metrics_queue (dict): Queue for the model's metrics, held per epoch until upload.
        timers (dict): Timers for rate limiting, managed in ultralytics/utils/callbacks/hub.py.
        filename (str | None): Local '.pt' or '.yaml' filename parsed from the identifier, if any.
        client (HUBClient): Client created with the resolved API-key credentials (or None credentials).
        model: Model object returned by `client.model()`; empty until an existing model is loaded or created.
        model_url (str | None): URL for the model in Ultralytics HUB, once a model is loaded or created.
        model_id (str | None): Identifier for the YOLO model, set when training arguments are resolved.
        model_file (str | None): Weights URL or architecture used to start or resume training.
        train_args (dict | None): Training arguments derived from the HUB model.
    """

    def __init__(self, identifier):
        """
        Initialize the HUBTrainingSession with the provided model identifier.

        Args:
            identifier (str): Model identifier used to initialize the HUB training session. It can be a HUB model
                URL, an 'apiKey_modelId' string, a bare model ID, or a local '.pt'/'.yaml' filename.

        Raises:
            HUBModelError: If the provided model identifier cannot be parsed.
            ValueError: If an existing model is already trained or its dataset is not yet available.
        """
        self.rate_limits = {
            'metrics': 3.0,
            'ckpt': 900.0,
            'heartbeat': 300.0, }  # rate limits (seconds)
        self.metrics_queue = {}  # holds metrics for each epoch until upload
        self.timers = {}  # holds timers in ultralytics/utils/callbacks/hub.py

        # These are only filled in by load_model()/_set_train_args()/create_model(); pre-declare them so that
        # reading them before a model exists yields None rather than AttributeError.
        self.model_url = None
        self.model_id = None
        self.model_file = None
        self.train_args = None

        # Parse input
        api_key, model_id, self.filename = self._parse_identifier(identifier)

        # Get credentials: an API key embedded in the identifier takes precedence over the saved setting
        active_key = api_key or SETTINGS.get('api_key')
        credentials = {'api_key': active_key} if active_key else None  # set credentials

        # Initialize client
        self.client = HUBClient(credentials)

        if model_id:
            self.load_model(model_id)  # load existing model
        else:
            self.model = self.client.model()  # load empty model

    def load_model(self, model_id):
        """
        Load an existing Ultralytics HUB model, resolve its training arguments and start heartbeats.

        Args:
            model_id (str): Identifier of the model in Ultralytics HUB.

        Raises:
            ValueError: If the model is already trained or its dataset is not yet available.
        """
        self.model = self.client.model(model_id)
        self.model_url = f'{HUB_WEB_ROOT}/models/{self.model.id}'

        self._set_train_args()

        # Start heartbeats for HUB to monitor agent
        self.model.start_heartbeat(self.rate_limits['heartbeat'])
        LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')

    def create_model(self, model_args):
        """
        Create a new model in Ultralytics HUB from local training arguments and start heartbeats for it.

        Args:
            model_args (dict): Training arguments; the keys read are 'batch', 'epochs', 'imgsz', 'patience',
                'device', 'cache' and 'data'.

        Note:
            If the HUB does not return a model id, the method returns silently without starting heartbeats.
        """
        payload = {
            'config': {
                'batchSize': model_args.get('batch', -1),
                'epochs': model_args.get('epochs', 300),
                'imageSize': model_args.get('imgsz', 640),
                'patience': model_args.get('patience', 100),
                'device': model_args.get('device', ''),
                'cache': model_args.get('cache', 'ram'), },
            'dataset': {
                'name': model_args.get('data')},
            'lineage': {
                'architecture': {
                    'name': self.filename.replace('.pt', '').replace('.yaml', ''), },
                'parent': {}, },
            'meta': {
                'name': self.filename}, }

        # A '.pt' file means training starts from pretrained weights, recorded as the model's parent
        if self.filename.endswith('.pt'):
            payload['lineage']['parent']['name'] = self.filename

        self.model.create_model(payload)

        # Model could not be created
        # TODO: improve error handling
        if not self.model.id:
            return

        self.model_url = f'{HUB_WEB_ROOT}/models/{self.model.id}'

        # Start heartbeats for HUB to monitor agent
        self.model.start_heartbeat(self.rate_limits['heartbeat'])
        LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')

    def _parse_identifier(self, identifier):
        """
        Parses the given identifier to determine the type of identifier and extract relevant components.

        The method supports different identifier formats:
            - A HUB URL, which starts with HUB_WEB_ROOT followed by '/models/'
            - An identifier containing an API key and a model ID separated by an underscore
            - An identifier that is solely a model ID of a fixed length
            - A local filename that ends with '.pt' or '.yaml'

        Args:
            identifier (str): The identifier string to be parsed.

        Returns:
            (tuple): A tuple (api_key, model_id, filename); components that do not apply are None.

        Raises:
            HUBModelError: If the identifier format is not recognized.
        """
        api_key, model_id, filename = None, None, None

        hub_models_prefix = f'{HUB_WEB_ROOT}/models/'
        if identifier.startswith(hub_models_prefix):
            # Extract the model_id after the HUB_WEB_ROOT URL
            model_id = identifier.split(hub_models_prefix)[-1]
        else:
            # Split the identifier based on underscores only if it's not a HUB URL
            parts = identifier.split('_')

            # 42-character API key + 20-character model ID
            if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
                api_key, model_id = parts
            # Single 20-character model ID
            elif len(parts) == 1 and len(parts[0]) == 20:
                model_id = parts[0]
            # Local weights or model config file
            elif identifier.endswith(('.pt', '.yaml')):
                filename = identifier
            else:
                raise HUBModelError(
                    f"model='{identifier}' could not be parsed. Check format is correct. "
                    f'Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file.')

        return api_key, model_id, filename

    def _set_train_args(self, **kwargs):
        """
        Populate `train_args`, `model_file` and `model_id` from the loaded HUB model.

        Resumable models reuse their last saved weights; otherwise the HUB config is converted to training arguments
        and the model file is either the parent weights (pretrained) or the architecture.

        Args:
            **kwargs: Unused by this method; accepted for backward compatibility.

        Raises:
            ValueError: If the model is already trained, or no dataset URL is available yet.
        """
        if self.model.is_trained():
            # Model is already trained
            raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀'))

        if self.model.is_resumable():
            # Model has saved weights
            self.train_args = {'data': self.model.get_dataset_url(), 'resume': True}
            self.model_file = self.model.get_weights_url('last')
        else:
            # Model has no saved weights: translate the HUB config keys into local training argument names
            config = self.model.data.get('config')
            self.train_args = {
                'batch': config['batchSize'],
                'epochs': config['epochs'],
                'imgsz': config['imageSize'],
                'patience': config['patience'],
                'device': config['device'],
                'cache': config['cache'],
                'data': self.model.get_dataset_url(), }

            # Set the model file as either a *.pt or *.yaml file
            self.model_file = (self.model.get_weights_url('parent')
                               if self.model.is_pretrained() else self.model.get_architecture())

        if not self.train_args.get('data'):
            raise ValueError('Dataset may still be processing. Please wait a minute and try again.')  # RF fix

        self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False)  # YOLOv5->YOLOv5u
        self.model_id = self.model.id

    def request_queue(
        self,
        request_func,
        retry=3,
        timeout=30,
        thread=True,
        verbose=True,
        progress_total=None,
        *args,
        **kwargs,
    ):
        """
        Call `request_func(*args, **kwargs)` with retries and exponential backoff, optionally in a daemon thread.

        Args:
            request_func (callable): Function performing the request; expected to return a response or None.
            retry (int): Number of retries after the first attempt.
            timeout (int | float): Overall time budget in seconds; no new attempt starts once it is exceeded.
            thread (bool): If True, run in a daemon thread and return None immediately.
            verbose (bool): If True, log the failure message for the first failed attempt.
            progress_total (int | None): If set, show a progress bar of this total size for each response.
            *args: Positional arguments forwarded to `request_func`.
            **kwargs: Keyword arguments forwarded to `request_func`.

        Returns:
            The last response (possibly None) when `thread` is False, otherwise None.
        """

        def retry_request():
            """Attempt the request up to `retry + 1` times; return the first 2xx response or the last one seen."""
            t0 = time.time()  # Record the start time for the timeout
            response = None  # Defined even if the timeout is hit before any attempt is made
            for i in range(retry + 1):
                if (time.time() - t0) > timeout:
                    LOGGER.warning(f'{PREFIX}Timeout for request reached. {HELP_MSG}')
                    break  # Timeout reached, exit loop

                response = request_func(*args, **kwargs)
                if response is None:
                    LOGGER.warning(f'{PREFIX}Received no response from the request. {HELP_MSG}')
                    if i < retry:  # No point backing off after the final attempt
                        time.sleep(2 ** i)  # Exponential backoff before retrying
                    continue  # Skip further processing and retry

                # Only display progress once we know there is a response object to read from
                if progress_total:
                    self._show_upload_progress(progress_total, response)

                if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES:
                    return response  # Success, no need to retry

                if i == 0:
                    # Initial attempt, check status code and provide messages
                    message = self._get_failure_message(response, retry, timeout)

                    if verbose:
                        LOGGER.warning(f'{PREFIX}{message} {HELP_MSG} ({response.status_code})')

                if not self._should_retry(response.status_code):
                    LOGGER.warning(f'{PREFIX}Request failed. {HELP_MSG} ({response.status_code})')
                    break  # Not an error that should be retried, exit loop

                if i < retry:  # No point backing off after the final attempt
                    time.sleep(2 ** i)  # Exponential backoff for retries

            return response

        if thread:
            # Start a new thread to run the retry_request function
            threading.Thread(target=retry_request, daemon=True).start()
        else:
            # If running in the main thread, call retry_request directly
            return retry_request()

    def _should_retry(self, status_code):
        """Return True if `status_code` is a transient error (408, 502, 504) worth retrying."""
        retry_codes = {
            HTTPStatus.REQUEST_TIMEOUT,
            HTTPStatus.BAD_GATEWAY,
            HTTPStatus.GATEWAY_TIMEOUT, }
        return status_code in retry_codes

    def _get_failure_message(self, response: requests.Response, retry: int, timeout: int):
        """
        Generate a retry message based on the response status code.

        Args:
            response: The HTTP response object.
            retry: The number of retry attempts allowed.
            timeout: The maximum timeout duration.

        Returns:
            str: The retry message.
        """
        if self._should_retry(response.status_code):
            return f'Retrying {retry}x for {timeout}s.' if retry else ''
        elif response.status_code == HTTPStatus.TOO_MANY_REQUESTS:  # rate limit
            headers = response.headers
            return (f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). "
                    f"Please retry after {headers['Retry-After']}s.")
        else:
            try:
                return response.json().get('message', 'No JSON message.')
            except (AttributeError, ValueError):
                # ValueError: body is not valid JSON (requests' JSONDecodeError subclasses it);
                # AttributeError: JSON decoded to something without .get (e.g. a list)
                return 'Unable to read JSON.'

    def upload_metrics(self):
        """Upload model metrics to Ultralytics HUB."""
        return self.request_queue(self.model.upload_metrics, metrics=self.metrics_queue.copy(), thread=True)

    def upload_model(
        self,
        epoch: int,
        weights: str,
        is_best: bool = False,
        map: float = 0.0,
        final: bool = False,
    ) -> None:
        """
        Upload a model checkpoint to Ultralytics HUB.

        Final uploads run synchronously with a progress bar; intermediate uploads run in a background thread.

        Args:
            epoch (int): The current training epoch.
            weights (str): Path to the model weights file.
            is_best (bool): Indicates if the current model is the best one so far.
            map (float): Mean average precision of the model.
            final (bool): Indicates if the model is the final model after training.
        """
        weights_path = Path(weights)
        if weights_path.is_file():
            progress_total = weights_path.stat().st_size if final else None  # Only show progress if final
            self.request_queue(
                self.model.upload_model,
                epoch=epoch,
                weights=weights,
                is_best=is_best,
                map=map,
                final=final,
                retry=10,
                timeout=3600,
                thread=not final,
                progress_total=progress_total,
            )
        else:
            LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload issue. Missing model {weights}.')

    def _show_upload_progress(self, content_length: int, response: requests.Response) -> None:
        """
        Display a progress bar of `content_length` bytes while consuming the response body in 1 KiB chunks.

        NOTE(review): this iterates the *response* content, so the bar tracks bytes received from the server rather
        than bytes uploaded — confirm this is the intended behavior.

        Args:
            content_length (int): Total size in bytes used for the progress bar (the uploaded file size at call site).
            response (requests.Response): The response object whose content is iterated.

        Returns:
            (None)
        """
        with TQDM(total=content_length, unit='B', unit_scale=True, unit_divisor=1024) as pbar:
            for data in response.iter_content(chunk_size=1024):
                pbar.update(len(data))