mirror of
				https://github.com/THU-MIG/yolov10.git
				synced 2025-10-31 14:35:40 +08:00 
			
		
		
		
	8.0.60 new HUB training syntax (#1753)
				
					
				
			Co-authored-by: Rafael Pierre <97888102+rafaelvp-db@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: Semih Demirel <85176438+semihhdemirel@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									e7876e1ba9
								
							
						
					
					
						commit
						84948651cd
					
				
							
								
								
									
										28
									
								
								.github/workflows/ci.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										28
									
								
								.github/workflows/ci.yaml
									
									
									
									
										vendored
									
									
								
							| @ -7,7 +7,7 @@ on: | |||||||
|   push: |   push: | ||||||
|     branches: [main] |     branches: [main] | ||||||
|   pull_request: |   pull_request: | ||||||
|     branches: [main] |     branches: [main, updates] | ||||||
|   schedule: |   schedule: | ||||||
|     - cron: '0 0 * * *'  # runs at 00:00 UTC every day |     - cron: '0 0 * * *'  # runs at 00:00 UTC every day | ||||||
| 
 | 
 | ||||||
| @ -43,16 +43,36 @@ jobs: | |||||||
|           python --version |           python --version | ||||||
|           pip --version |           pip --version | ||||||
|           pip list |           pip list | ||||||
|       - name: Test HUB training |       - name: Test HUB training (Python Usage 1) | ||||||
|         shell: python |         shell: python | ||||||
|         env: |         env: | ||||||
|           APIKEY: ${{ secrets.ULTRALYTICS_HUB_APIKEY }} |           APIKEY: ${{ secrets.ULTRALYTICS_HUB_APIKEY }} | ||||||
|         run: | |         run: | | ||||||
|           import os |           import os | ||||||
|           from ultralytics import hub |           from pathlib import Path | ||||||
|  |           from ultralytics import YOLO, hub | ||||||
|  |           from ultralytics.yolo.utils import USER_CONFIG_DIR | ||||||
|  |           Path(USER_CONFIG_DIR / 'settings.yaml').unlink() | ||||||
|           key = os.environ['APIKEY'] |           key = os.environ['APIKEY'] | ||||||
|           hub.reset_model(key) |           hub.reset_model(key) | ||||||
|           hub.start(key) |           model = YOLO('https://hub.ultralytics.com/models/' + key) | ||||||
|  |           model.train() | ||||||
|  |       - name: Test HUB training (Python Usage 2) | ||||||
|  |         shell: python | ||||||
|  |         env: | ||||||
|  |           APIKEY: ${{ secrets.ULTRALYTICS_HUB_APIKEY }} | ||||||
|  |         run: | | ||||||
|  |           import os | ||||||
|  |           from pathlib import Path | ||||||
|  |           from ultralytics import YOLO, hub | ||||||
|  |           from ultralytics.yolo.utils import USER_CONFIG_DIR | ||||||
|  |           Path(USER_CONFIG_DIR / 'settings.yaml').unlink() | ||||||
|  |           key = os.environ['APIKEY'] | ||||||
|  |           hub.reset_model(key) | ||||||
|  |           key, model_id = key.split('_') | ||||||
|  |           hub.login(key) | ||||||
|  |           model = YOLO('https://hub.ultralytics.com/models/' + model_id) | ||||||
|  |           model.train() | ||||||
| 
 | 
 | ||||||
|   Benchmarks: |   Benchmarks: | ||||||
|     runs-on: ${{ matrix.os }} |     runs-on: ${{ matrix.os }} | ||||||
|  | |||||||
| @ -26,6 +26,7 @@ WORKDIR /usr/src/ultralytics | |||||||
| # Copy contents | # Copy contents | ||||||
| # COPY . /usr/src/app  (issues as not a .git directory) | # COPY . /usr/src/app  (issues as not a .git directory) | ||||||
| RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics | RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics | ||||||
|  | ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt /usr/src/ultralytics/ | ||||||
| 
 | 
 | ||||||
| # Install pip packages | # Install pip packages | ||||||
| RUN python3 -m pip install --upgrade pip wheel | RUN python3 -m pip install --upgrade pip wheel | ||||||
|  | |||||||
| @ -22,6 +22,7 @@ WORKDIR /usr/src/ultralytics | |||||||
| # Copy contents | # Copy contents | ||||||
| # COPY . /usr/src/app  (issues as not a .git directory) | # COPY . /usr/src/app  (issues as not a .git directory) | ||||||
| RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics | RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics | ||||||
|  | ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt /usr/src/ultralytics/ | ||||||
| 
 | 
 | ||||||
| # Install pip packages | # Install pip packages | ||||||
| RUN python3 -m pip install --upgrade pip wheel | RUN python3 -m pip install --upgrade pip wheel | ||||||
|  | |||||||
| @ -22,6 +22,7 @@ WORKDIR /usr/src/ultralytics | |||||||
| # Copy contents | # Copy contents | ||||||
| # COPY . /usr/src/app  (issues as not a .git directory) | # COPY . /usr/src/app  (issues as not a .git directory) | ||||||
| RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics | RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics | ||||||
|  | ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt /usr/src/ultralytics/ | ||||||
| 
 | 
 | ||||||
| # Install pip packages | # Install pip packages | ||||||
| RUN python3 -m pip install --upgrade pip wheel | RUN python3 -m pip install --upgrade pip wheel | ||||||
|  | |||||||
| @ -17,7 +17,7 @@ passing `stream=True` in the predictor's call method. | |||||||
|             probs = result.probs  # Class probabilities for classification outputs |             probs = result.probs  # Class probabilities for classification outputs | ||||||
|         ``` |         ``` | ||||||
| 
 | 
 | ||||||
|     === "Return a list with `Stream=True`" |     === "Return a generator with `Stream=True`" | ||||||
|         ```python |         ```python | ||||||
|         inputs = [img, img]  # list of numpy arrays |         inputs = [img, img]  # list of numpy arrays | ||||||
|         results = model(inputs, stream=True)  # generator of Results objects |         results = model(inputs, stream=True)  # generator of Results objects | ||||||
| @ -54,6 +54,40 @@ whether each source can be used in streaming mode with `stream=True` ✅ and an | |||||||
| | YouTube ✅   | `'https://youtu.be/Zgi9g1ksQHc'`           | `str`          |                  | | | YouTube ✅   | `'https://youtu.be/Zgi9g1ksQHc'`           | `str`          |                  | | ||||||
| | stream ✅    | `'rtsp://example.com/media.mp4'`           | `str`          | RTSP, RTMP, HTTP | | | stream ✅    | `'rtsp://example.com/media.mp4'`           | `str`          | RTSP, RTMP, HTTP | | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|  | ## Arguments | ||||||
|  | `model.predict` accepts multiple arguments that control the predction operation. These arguments can be passed directly to `model.predict`: | ||||||
|  | !!! example | ||||||
|  |     ``` | ||||||
|  |     model.predict(source, save=True, imgsz=320, conf=0.5) | ||||||
|  |     ``` | ||||||
|  | 
 | ||||||
|  | All supported arguments: | ||||||
|  | 
 | ||||||
|  | | Key              | Value                  | Description                                              | | ||||||
|  | |------------------|------------------------|----------------------------------------------------------| | ||||||
|  | | `source`         | `'ultralytics/assets'` | source directory for images or videos                    | | ||||||
|  | | `conf`           | `0.25`                 | object confidence threshold for detection                | | ||||||
|  | | `iou`            | `0.7`                  | intersection over union (IoU) threshold for NMS          | | ||||||
|  | | `half`           | `False`                | use half precision (FP16)                                | | ||||||
|  | | `device`         | `None`                 | device to run on, i.e. cuda device=0/1/2/3 or device=cpu | | ||||||
|  | | `show`           | `False`                | show results if possible                                 | | ||||||
|  | | `save`           | `False`                | save images with results                                 | | ||||||
|  | | `save_txt`       | `False`                | save results as .txt file                                | | ||||||
|  | | `save_conf`      | `False`                | save results with confidence scores                      | | ||||||
|  | | `save_crop`      | `False`                | save cropped images with results                         | | ||||||
|  | | `hide_labels`    | `False`                | hide labels                                              | | ||||||
|  | | `hide_conf`      | `False`                | hide confidence scores                                   | | ||||||
|  | | `max_det`        | `300`                  | maximum number of detections per image                   | | ||||||
|  | | `vid_stride`     | `False`                | video frame-rate stride                                  | | ||||||
|  | | `line_thickness` | `3`                    | bounding box thickness (pixels)                          | | ||||||
|  | | `visualize`      | `False`                | visualize model features                                 | | ||||||
|  | | `augment`        | `False`                | apply image augmentation to prediction sources           | | ||||||
|  | | `agnostic_nms`   | `False`                | class-agnostic NMS                                       | | ||||||
|  | | `retina_masks`   | `False`                | use high-resolution segmentation masks                   | | ||||||
|  | | `classes`        | `None`                 | filter results by class, i.e. class=0, or class=[0,2,3]  | | ||||||
|  | | `boxes`          | `True`                 | Show boxes in segmentation predictions                   | | ||||||
|  | 
 | ||||||
| ## Image and Video Formats | ## Image and Video Formats | ||||||
| 
 | 
 | ||||||
| YOLOv8 supports various image and video formats, as specified | YOLOv8 supports various image and video formats, as specified | ||||||
|  | |||||||
| @ -96,7 +96,6 @@ names: | |||||||
|   77: teddy bear |   77: teddy bear | ||||||
|   78: hair drier |   78: hair drier | ||||||
|   79: toothbrush |   79: toothbrush | ||||||
| 
 |  | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -6,7 +6,7 @@ Just simply clone and run | |||||||
| 
 | 
 | ||||||
| ```bash | ```bash | ||||||
| pip install -r requirements.txt | pip install -r requirements.txt | ||||||
| python main.py | python main.py --model yolov8n.onnx --img image.jpg | ||||||
| ``` | ``` | ||||||
| 
 | 
 | ||||||
| If you start from scratch: | If you start from scratch: | ||||||
|  | |||||||
| @ -1,3 +1,5 @@ | |||||||
|  | import argparse | ||||||
|  | 
 | ||||||
| import cv2.dnn | import cv2.dnn | ||||||
| import numpy as np | import numpy as np | ||||||
| 
 | 
 | ||||||
| @ -16,9 +18,9 @@ def draw_bounding_box(img, class_id, confidence, x, y, x_plus_w, y_plus_h): | |||||||
|     cv2.putText(img, label, (x - 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2) |     cv2.putText(img, label, (x - 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def main(): | def main(onnx_model, input_image): | ||||||
|     model: cv2.dnn.Net = cv2.dnn.readNetFromONNX('yolov8n.onnx') |     model: cv2.dnn.Net = cv2.dnn.readNetFromONNX(onnx_model) | ||||||
|     original_image: np.ndarray = cv2.imread(str(ROOT / 'assets/bus.jpg')) |     original_image: np.ndarray = cv2.imread(input_image) | ||||||
|     [height, width, _] = original_image.shape |     [height, width, _] = original_image.shape | ||||||
|     length = max((height, width)) |     length = max((height, width)) | ||||||
|     image = np.zeros((length, length, 3), np.uint8) |     image = np.zeros((length, length, 3), np.uint8) | ||||||
| @ -71,4 +73,8 @@ def main(): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     main() |     parser = argparse.ArgumentParser() | ||||||
|  |     parser.add_argument('--model', default='yolov8n.onnx', help='Input your onnx model.') | ||||||
|  |     parser.add_argument('--img', default=str(ROOT / 'assets/bus.jpg'), help='Path to input image.') | ||||||
|  |     args = parser.parse_args() | ||||||
|  |     main(args.model, args.img) | ||||||
|  | |||||||
| @ -46,7 +46,7 @@ theme: | |||||||
|     - content.tabs.link  # all code tabs change simultaneously |     - content.tabs.link  # all code tabs change simultaneously | ||||||
| 
 | 
 | ||||||
| # Customization | # Customization | ||||||
| copyright: Ultralytics 2023. All rights reserved. | copyright: <a href="https://ultralytics.com" target="_blank">Ultralytics 2023.</a> All rights reserved. | ||||||
| extra: | extra: | ||||||
|   # version: |   # version: | ||||||
|   #   provider: mike  #  version drop-down menu |   #   provider: mike  #  version drop-down menu | ||||||
| @ -167,7 +167,7 @@ nav: | |||||||
|       - Hyperparameter evolution: yolov5/hyp_evolution.md |       - Hyperparameter evolution: yolov5/hyp_evolution.md | ||||||
|       - Transfer learning with frozen layers: yolov5/transfer_learn_frozen.md |       - Transfer learning with frozen layers: yolov5/transfer_learn_frozen.md | ||||||
|       - Architecture Summary: yolov5/architecture.md |       - Architecture Summary: yolov5/architecture.md | ||||||
|       - Roboflow for Datasets, Labeling, and Active Learning: yolov5/roboflow.md |       - Roboflow Datasets: yolov5/roboflow.md | ||||||
|       - Neural Magic's DeepSparse: yolov5/neural_magic.md |       - Neural Magic's DeepSparse: yolov5/neural_magic.md | ||||||
|       - Comet Logging: yolov5/comet.md |       - Comet Logging: yolov5/comet.md | ||||||
|       - Clearml Logging: yolov5/clearml.md |       - Clearml Logging: yolov5/clearml.md | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								setup.py
									
									
									
									
									
								
							| @ -58,7 +58,7 @@ setup( | |||||||
|         'Topic :: Scientific/Engineering :: Artificial Intelligence', |         'Topic :: Scientific/Engineering :: Artificial Intelligence', | ||||||
|         'Topic :: Scientific/Engineering :: Image Recognition', |         'Topic :: Scientific/Engineering :: Image Recognition', | ||||||
|         'Operating System :: POSIX :: Linux', |         'Operating System :: POSIX :: Linux', | ||||||
|         'Operating System :: macOS', |         'Operating System :: MacOS', | ||||||
|         'Operating System :: Microsoft :: Windows', ], |         'Operating System :: Microsoft :: Windows', ], | ||||||
|     keywords='machine-learning, deep-learning, vision, ML, DL, AI, YOLO, YOLOv3, YOLOv5, YOLOv8, HUB, Ultralytics', |     keywords='machine-learning, deep-learning, vision, ML, DL, AI, YOLO, YOLOv3, YOLOv5, YOLOv8, HUB, Ultralytics', | ||||||
|     entry_points={ |     entry_points={ | ||||||
|  | |||||||
| @ -56,11 +56,11 @@ def test_predict_detect(): | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_predict_segment(): | def test_predict_segment(): | ||||||
|     run(f"yolo predict model={MODEL}-seg.pt source={ROOT / 'assets'} imgsz=32 save") |     run(f"yolo predict model={MODEL}-seg.pt source={ROOT / 'assets'} imgsz=32 save save_txt") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_predict_classify(): | def test_predict_classify(): | ||||||
|     run(f"yolo predict model={MODEL}-cls.pt source={ROOT / 'assets'} imgsz=32 save") |     run(f"yolo predict model={MODEL}-cls.pt source={ROOT / 'assets'} imgsz=32 save save_txt") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| # Export checks -------------------------------------------------------------------------------------------------------- | # Export checks -------------------------------------------------------------------------------------------------------- | ||||||
|  | |||||||
| @ -1,8 +1,9 @@ | |||||||
| # Ultralytics YOLO 🚀, GPL-3.0 license | # Ultralytics YOLO 🚀, GPL-3.0 license | ||||||
| 
 | 
 | ||||||
| __version__ = '8.0.59' | __version__ = '8.0.60' | ||||||
| 
 | 
 | ||||||
|  | from ultralytics.hub import start | ||||||
| from ultralytics.yolo.engine.model import YOLO | from ultralytics.yolo.engine.model import YOLO | ||||||
| from ultralytics.yolo.utils.checks import check_yolo as checks | from ultralytics.yolo.utils.checks import check_yolo as checks | ||||||
| 
 | 
 | ||||||
| __all__ = '__version__', 'YOLO', 'checks'  # allow simpler import | __all__ = '__version__', 'YOLO', 'checks', 'start'  # allow simpler import | ||||||
|  | |||||||
| @ -2,47 +2,51 @@ | |||||||
| 
 | 
 | ||||||
| import requests | import requests | ||||||
| 
 | 
 | ||||||
| from ultralytics.hub.auth import Auth |  | ||||||
| from ultralytics.hub.session import HUBTrainingSession |  | ||||||
| from ultralytics.hub.utils import PREFIX, split_key | from ultralytics.hub.utils import PREFIX, split_key | ||||||
| from ultralytics.yolo.engine.model import YOLO | from ultralytics.yolo.utils import LOGGER | ||||||
| from ultralytics.yolo.utils import LOGGER, emojis | 
 | ||||||
|  | 
 | ||||||
|  | def login(api_key=''): | ||||||
|  |     """ | ||||||
|  |     Log in to the Ultralytics HUB API using the provided API key. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id | ||||||
|  | 
 | ||||||
|  |     Example: | ||||||
|  |         from ultralytics import hub | ||||||
|  |         hub.login('your_api_key') | ||||||
|  |     """ | ||||||
|  |     from ultralytics.hub.auth import Auth | ||||||
|  |     Auth(api_key) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def logout(): | ||||||
|  |     """ | ||||||
|  |     Logout Ultralytics HUB | ||||||
|  | 
 | ||||||
|  |     Example: | ||||||
|  |         from ultralytics import hub | ||||||
|  |         hub.logout() | ||||||
|  |     """ | ||||||
|  |     LOGGER.warning('WARNING ⚠️ This method is not yet implemented.') | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def start(key=''): | def start(key=''): | ||||||
|     """ |     """ | ||||||
|     Start training models with Ultralytics HUB. Usage: from ultralytics.hub import start; start('API_KEY') |     Start training models with Ultralytics HUB (DEPRECATED). | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         key (str, optional): A string containing either the API key and model ID combination (apikey_modelid), | ||||||
|  |                                or the full model URL (https://hub.ultralytics.com/models/apikey_modelid). | ||||||
|     """ |     """ | ||||||
|     auth = Auth(key) |     LOGGER.warning(f""" | ||||||
|     model_id = split_key(key)[1] if auth.get_state() else request_api_key(auth) | WARNING ⚠️ ultralytics.start() is deprecated in 8.0.60. Updated usage to train your Ultralytics HUB model is below: | ||||||
|     if not model_id: |  | ||||||
|         raise ConnectionError(emojis('Connecting with global API key is not currently supported. ❌')) |  | ||||||
| 
 | 
 | ||||||
|     session = HUBTrainingSession(model_id=model_id, auth=auth) | from ultralytics import YOLO | ||||||
|     session.check_disk_space() |  | ||||||
| 
 | 
 | ||||||
|     model = YOLO(model=session.model_file, session=session) | model = YOLO('https://hub.ultralytics.com/models/{key}') | ||||||
|     model.train(**session.train_args) | model.train()""") | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def request_api_key(auth, max_attempts=3): |  | ||||||
|     """ |  | ||||||
|     Prompt the user to input their API key. Returns the model ID. |  | ||||||
|     """ |  | ||||||
|     import getpass |  | ||||||
|     for attempts in range(max_attempts): |  | ||||||
|         LOGGER.info(f'{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}') |  | ||||||
|         input_key = getpass.getpass( |  | ||||||
|             'Enter your Ultralytics API Key from https://hub.ultralytics.com/settings?tab=api+keys:\n') |  | ||||||
|         auth.api_key, model_id = split_key(input_key) |  | ||||||
| 
 |  | ||||||
|         if auth.authenticate(): |  | ||||||
|             LOGGER.info(f'{PREFIX}Authenticated ✅') |  | ||||||
|             return model_id |  | ||||||
| 
 |  | ||||||
|         LOGGER.warning(f'{PREFIX}Invalid API key ⚠️\n') |  | ||||||
| 
 |  | ||||||
|     raise ConnectionError(emojis(f'{PREFIX}Failed to authenticate ❌')) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def reset_model(key=''): | def reset_model(key=''): | ||||||
|  | |||||||
| @ -2,27 +2,74 @@ | |||||||
| 
 | 
 | ||||||
| import requests | import requests | ||||||
| 
 | 
 | ||||||
| from ultralytics.hub.utils import HUB_API_ROOT, request_with_credentials | from ultralytics.hub.utils import HUB_API_ROOT, PREFIX, request_with_credentials | ||||||
| from ultralytics.yolo.utils import is_colab | from ultralytics.yolo.utils import LOGGER, SETTINGS, emojis, is_colab, set_settings | ||||||
| 
 | 
 | ||||||
| API_KEY_PATH = 'https://hub.ultralytics.com/settings?tab=api+keys' | API_KEY_URL = 'https://hub.ultralytics.com/settings?tab=api+keys' | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class Auth: | class Auth: | ||||||
|     id_token = api_key = model_key = False |     id_token = api_key = model_key = False | ||||||
| 
 | 
 | ||||||
|     def __init__(self, api_key=None): |     def __init__(self, api_key=''): | ||||||
|         self.api_key = self._clean_api_key(api_key) |         """ | ||||||
|         self.authenticate() if self.api_key else self.auth_with_cookies() |         Initialize the Auth class with an optional API key. | ||||||
| 
 | 
 | ||||||
|     @staticmethod |         Args: | ||||||
|     def _clean_api_key(key: str) -> str: |             api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id | ||||||
|         """Strip model from key if present""" |         """ | ||||||
|         separator = '_' |         # Split the input API key in case it contains a combined key_model and keep only the API key part | ||||||
|         return key.split(separator)[0] if separator in key else key |         api_key = api_key.split('_')[0] | ||||||
|  | 
 | ||||||
|  |         # Set API key attribute as value passed or SETTINGS API key if none passed | ||||||
|  |         self.api_key = api_key or SETTINGS.get('api_key', '') | ||||||
|  | 
 | ||||||
|  |         # If an API key is provided | ||||||
|  |         if self.api_key: | ||||||
|  |             # If the provided API key matches the API key in the SETTINGS | ||||||
|  |             if self.api_key == SETTINGS.get('api_key'): | ||||||
|  |                 # Log that the user is already logged in | ||||||
|  |                 LOGGER.info(f'{PREFIX}Authenticated ✅') | ||||||
|  |                 return | ||||||
|  |             else: | ||||||
|  |                 # Attempt to authenticate with the provided API key | ||||||
|  |                 success = self.authenticate() | ||||||
|  |         # If the API key is not provided and the environment is a Google Colab notebook | ||||||
|  |         elif is_colab(): | ||||||
|  |             # Attempt to authenticate using browser cookies | ||||||
|  |             success = self.auth_with_cookies() | ||||||
|  |         else: | ||||||
|  |             # Request an API key | ||||||
|  |             success = self.request_api_key() | ||||||
|  | 
 | ||||||
|  |         # Update SETTINGS with the new API key after successful authentication | ||||||
|  |         if success: | ||||||
|  |             set_settings({'api_key': self.api_key}) | ||||||
|  |             # Log that the new login was successful | ||||||
|  |             LOGGER.info(f'{PREFIX}New authentication successful ✅') | ||||||
|  |         else: | ||||||
|  |             LOGGER.info(f'{PREFIX}Retrieve API key from {API_KEY_URL}') | ||||||
|  | 
 | ||||||
|  |     def request_api_key(self, max_attempts=3): | ||||||
|  |         """ | ||||||
|  |         Prompt the user to input their API key. Returns the model ID. | ||||||
|  |         """ | ||||||
|  |         import getpass | ||||||
|  |         for attempts in range(max_attempts): | ||||||
|  |             LOGGER.info(f'{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}') | ||||||
|  |             input_key = getpass.getpass(f'Enter API key from {API_KEY_URL} ') | ||||||
|  |             self.api_key = input_key.split('_')[0]  # remove model id if present | ||||||
|  |             if self.authenticate(): | ||||||
|  |                 return True | ||||||
|  |         raise ConnectionError(emojis(f'{PREFIX}Failed to authenticate ❌')) | ||||||
| 
 | 
 | ||||||
|     def authenticate(self) -> bool: |     def authenticate(self) -> bool: | ||||||
|         """Attempt to authenticate with server""" |         """ | ||||||
|  |         Attempt to authenticate with the server using either id_token or API key. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             bool: True if authentication is successful, False otherwise. | ||||||
|  |         """ | ||||||
|         try: |         try: | ||||||
|             header = self.get_auth_header() |             header = self.get_auth_header() | ||||||
|             if header: |             if header: | ||||||
| @ -33,12 +80,16 @@ class Auth: | |||||||
|             raise ConnectionError('User has not authenticated locally.') |             raise ConnectionError('User has not authenticated locally.') | ||||||
|         except ConnectionError: |         except ConnectionError: | ||||||
|             self.id_token = self.api_key = False  # reset invalid |             self.id_token = self.api_key = False  # reset invalid | ||||||
|  |             LOGGER.warning(f'{PREFIX}Invalid API key ⚠️') | ||||||
|             return False |             return False | ||||||
| 
 | 
 | ||||||
|     def auth_with_cookies(self) -> bool: |     def auth_with_cookies(self) -> bool: | ||||||
|         """ |         """ | ||||||
|         Attempt to fetch authentication via cookies and set id_token. |         Attempt to fetch authentication via cookies and set id_token. | ||||||
|         User must be logged in to HUB and running in a supported browser. |         User must be logged in to HUB and running in a supported browser. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             bool: True if authentication is successful, False otherwise. | ||||||
|         """ |         """ | ||||||
|         if not is_colab(): |         if not is_colab(): | ||||||
|             return False  # Currently only works with Colab |             return False  # Currently only works with Colab | ||||||
| @ -54,6 +105,12 @@ class Auth: | |||||||
|             return False |             return False | ||||||
| 
 | 
 | ||||||
|     def get_auth_header(self): |     def get_auth_header(self): | ||||||
|  |         """ | ||||||
|  |         Get the authentication header for making API requests. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             dict: The authentication header if id_token or API key is set, None otherwise. | ||||||
|  |         """ | ||||||
|         if self.id_token: |         if self.id_token: | ||||||
|             return {'authorization': f'Bearer {self.id_token}'} |             return {'authorization': f'Bearer {self.id_token}'} | ||||||
|         elif self.api_key: |         elif self.api_key: | ||||||
| @ -62,9 +119,19 @@ class Auth: | |||||||
|             return None |             return None | ||||||
| 
 | 
 | ||||||
|     def get_state(self) -> bool: |     def get_state(self) -> bool: | ||||||
|         """Get the authentication state""" |         """ | ||||||
|  |         Get the authentication state. | ||||||
|  | 
 | ||||||
|  |         Returns: | ||||||
|  |             bool: True if either id_token or API key is set, False otherwise. | ||||||
|  |         """ | ||||||
|         return self.id_token or self.api_key |         return self.id_token or self.api_key | ||||||
| 
 | 
 | ||||||
|     def set_api_key(self, key: str): |     def set_api_key(self, key: str): | ||||||
|         """Get the authentication state""" |         """ | ||||||
|  |         Set the API key for authentication. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             key (str): The API key string. | ||||||
|  |         """ | ||||||
|         self.api_key = key |         self.api_key = key | ||||||
|  | |||||||
| @ -6,17 +6,62 @@ from time import sleep | |||||||
| 
 | 
 | ||||||
| import requests | import requests | ||||||
| 
 | 
 | ||||||
| from ultralytics.hub.utils import HUB_API_ROOT, check_dataset_disk_space, smart_request | from ultralytics.hub.utils import HUB_API_ROOT, PREFIX, check_dataset_disk_space, smart_request | ||||||
| from ultralytics.yolo.utils import LOGGER, PREFIX, __version__, checks, emojis, is_colab, threaded | from ultralytics.yolo.utils import LOGGER, __version__, checks, emojis, is_colab, threaded | ||||||
| 
 | 
 | ||||||
| AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local' | AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local' | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class HUBTrainingSession: | class HUBTrainingSession: | ||||||
|  |     """ | ||||||
|  |     HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing. | ||||||
| 
 | 
 | ||||||
|     def __init__(self, model_id, auth): |     Args: | ||||||
|  |         url (str): Model identifier used to initialize the HUB training session. | ||||||
|  | 
 | ||||||
|  |     Attributes: | ||||||
|  |         agent_id (str): Identifier for the instance communicating with the server. | ||||||
|  |         model_id (str): Identifier for the YOLOv5 model being trained. | ||||||
|  |         model_url (str): URL for the model in Ultralytics HUB. | ||||||
|  |         api_url (str): API URL for the model in Ultralytics HUB. | ||||||
|  |         auth_header (Dict): Authentication header for the Ultralytics HUB API requests. | ||||||
|  |         rate_limits (Dict): Rate limits for different API calls (in seconds). | ||||||
|  |         timers (Dict): Timers for rate limiting. | ||||||
|  |         metrics_queue (Dict): Queue for the model's metrics. | ||||||
|  |         model (Dict): Model data fetched from Ultralytics HUB. | ||||||
|  |         alive (bool): Indicates if the heartbeat loop is active. | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     def __init__(self, url): | ||||||
|  |         """ | ||||||
|  |         Initialize the HUBTrainingSession with the provided model identifier. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             url (str): Model identifier used to initialize the HUB training session. | ||||||
|  |                          It can be a URL string or a model key with specific format. | ||||||
|  | 
 | ||||||
|  |         Raises: | ||||||
|  |             ValueError: If the provided model identifier is invalid. | ||||||
|  |             ConnectionError: If connecting with global API key is not supported. | ||||||
|  |         """ | ||||||
|  | 
 | ||||||
|  |         from ultralytics.hub.auth import Auth | ||||||
|  | 
 | ||||||
|  |         # Parse input | ||||||
|  |         if url.startswith('https://hub.ultralytics.com/models/'): | ||||||
|  |             url = url.split('https://hub.ultralytics.com/models/')[-1] | ||||||
|  |         if [len(x) for x in url.split('_')] == [42, 20]: | ||||||
|  |             key, model_id = url.split('_') | ||||||
|  |         elif len(url) == 20: | ||||||
|  |             key, model_id = '', url | ||||||
|  |         else: | ||||||
|  |             raise ValueError(f'Invalid HUBTrainingSession input: {url}') | ||||||
|  | 
 | ||||||
|  |         # Authorize | ||||||
|  |         auth = Auth(key) | ||||||
|         self.agent_id = None  # identifies which instance is communicating with server |         self.agent_id = None  # identifies which instance is communicating with server | ||||||
|         self.model_id = model_id |         self.model_id = model_id | ||||||
|  |         self.model_url = f'https://hub.ultralytics.com/models/{model_id}' | ||||||
|         self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}' |         self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}' | ||||||
|         self.auth_header = auth.get_auth_header() |         self.auth_header = auth.get_auth_header() | ||||||
|         self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0}  # rate limits (seconds) |         self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0}  # rate limits (seconds) | ||||||
| @ -26,16 +71,17 @@ class HUBTrainingSession: | |||||||
|         self.alive = True |         self.alive = True | ||||||
|         self._start_heartbeat()  # start heartbeats |         self._start_heartbeat()  # start heartbeats | ||||||
|         self._register_signal_handlers() |         self._register_signal_handlers() | ||||||
|  |         LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀') | ||||||
| 
 | 
 | ||||||
|     def _register_signal_handlers(self): |     def _register_signal_handlers(self): | ||||||
|  |         """Register signal handlers for SIGTERM and SIGINT signals to gracefully handle termination.""" | ||||||
|         signal.signal(signal.SIGTERM, self._handle_signal) |         signal.signal(signal.SIGTERM, self._handle_signal) | ||||||
|         signal.signal(signal.SIGINT, self._handle_signal) |         signal.signal(signal.SIGINT, self._handle_signal) | ||||||
| 
 | 
 | ||||||
|     def _handle_signal(self, signum, frame): |     def _handle_signal(self, signum, frame): | ||||||
|         """ |         """ | ||||||
|         Prevent heartbeats from being sent on Colab after kill. |         Handle kill signals and prevent heartbeats from being sent on Colab after termination. | ||||||
|         This method does not use frame, it is included as it is |         This method does not use frame, it is included as it is passed by signal. | ||||||
|         passed by signal. |  | ||||||
|         """ |         """ | ||||||
|         if self.alive is True: |         if self.alive is True: | ||||||
|             LOGGER.info(f'{PREFIX}Kill signal received! ❌') |             LOGGER.info(f'{PREFIX}Kill signal received! ❌') | ||||||
| @ -43,15 +89,16 @@ class HUBTrainingSession: | |||||||
|             sys.exit(signum) |             sys.exit(signum) | ||||||
| 
 | 
 | ||||||
|     def _stop_heartbeat(self): |     def _stop_heartbeat(self): | ||||||
|         """End the heartbeat loop""" |         """Terminate the heartbeat loop.""" | ||||||
|         self.alive = False |         self.alive = False | ||||||
| 
 | 
 | ||||||
|     def upload_metrics(self): |     def upload_metrics(self): | ||||||
|  |         """Upload model metrics to Ultralytics HUB.""" | ||||||
|         payload = {'metrics': self.metrics_queue.copy(), 'type': 'metrics'} |         payload = {'metrics': self.metrics_queue.copy(), 'type': 'metrics'} | ||||||
|         smart_request('post', self.api_url, json=payload, headers=self.auth_header, code=2) |         smart_request('post', self.api_url, json=payload, headers=self.auth_header, code=2) | ||||||
| 
 | 
 | ||||||
|     def _get_model(self): |     def _get_model(self): | ||||||
|         # Returns model from database by id |         """Fetch and return model data from Ultralytics HUB.""" | ||||||
|         api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}' |         api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}' | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
| @ -59,9 +106,7 @@ class HUBTrainingSession: | |||||||
|             data = response.json().get('data', None) |             data = response.json().get('data', None) | ||||||
| 
 | 
 | ||||||
|             if data.get('status', None) == 'trained': |             if data.get('status', None) == 'trained': | ||||||
|                 raise ValueError( |                 raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀')) | ||||||
|                     emojis(f'Model is already trained and uploaded to ' |  | ||||||
|                            f'https://hub.ultralytics.com/models/{self.model_id} 🚀')) |  | ||||||
| 
 | 
 | ||||||
|             if not data.get('data', None): |             if not data.get('data', None): | ||||||
|                 raise ValueError('Dataset may still be processing. Please wait a minute and try again.')  # RF fix |                 raise ValueError('Dataset may still be processing. Please wait a minute and try again.')  # RF fix | ||||||
| @ -88,11 +133,21 @@ class HUBTrainingSession: | |||||||
|             raise |             raise | ||||||
| 
 | 
 | ||||||
|     def check_disk_space(self): |     def check_disk_space(self): | ||||||
|         if not check_dataset_disk_space(self.model['data']): |         """Check if there is enough disk space for the dataset.""" | ||||||
|  |         if not check_dataset_disk_space(url=self.model['data']): | ||||||
|             raise MemoryError('Not enough disk space') |             raise MemoryError('Not enough disk space') | ||||||
| 
 | 
 | ||||||
|     def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False): |     def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False): | ||||||
|         # Upload a model to HUB |         """ | ||||||
|  |         Upload a model checkpoint to Ultralytics HUB. | ||||||
|  | 
 | ||||||
|  |         Args: | ||||||
|  |             epoch (int): The current training epoch. | ||||||
|  |             weights (str): Path to the model weights file. | ||||||
|  |             is_best (bool): Indicates if the current model is the best one so far. | ||||||
|  |             map (float): Mean average precision of the model. | ||||||
|  |             final (bool): Indicates if the model is the final model after training. | ||||||
|  |         """ | ||||||
|         if Path(weights).is_file(): |         if Path(weights).is_file(): | ||||||
|             with open(weights, 'rb') as f: |             with open(weights, 'rb') as f: | ||||||
|                 file = f.read() |                 file = f.read() | ||||||
| @ -120,6 +175,7 @@ class HUBTrainingSession: | |||||||
| 
 | 
 | ||||||
|     @threaded |     @threaded | ||||||
|     def _start_heartbeat(self): |     def _start_heartbeat(self): | ||||||
|  |         """Begin a threaded heartbeat loop to report the agent's status to Ultralytics HUB.""" | ||||||
|         while self.alive: |         while self.alive: | ||||||
|             r = smart_request('post', |             r = smart_request('post', | ||||||
|                               f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}', |                               f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}', | ||||||
|  | |||||||
| @ -22,7 +22,16 @@ HUB_API_ROOT = os.environ.get('ULTRALYTICS_HUB_API', 'https://api.ultralytics.co | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def check_dataset_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=2.0): | def check_dataset_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=2.0): | ||||||
|     # Check that url fits on disk with safety factor sf, i.e. require 2GB free if url size is 1GB with sf=2.0 |     """ | ||||||
|  |     Check if there is sufficient disk space to download and store a dataset. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         url (str, optional): The URL to the dataset file. Defaults to 'https://ultralytics.com/assets/coco128.zip'. | ||||||
|  |         sf (float, optional): Safety factor, the multiplier for the required free space. Defaults to 2.0. | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         bool: True if there is sufficient disk space, False otherwise. | ||||||
|  |     """ | ||||||
|     gib = 1 << 30  # bytes per GiB |     gib = 1 << 30  # bytes per GiB | ||||||
|     data = int(requests.head(url).headers['Content-Length']) / gib  # dataset size (GB) |     data = int(requests.head(url).headers['Content-Length']) / gib  # dataset size (GB) | ||||||
|     total, used, free = (x / gib for x in shutil.disk_usage('/'))  # bytes |     total, used, free = (x / gib for x in shutil.disk_usage('/'))  # bytes | ||||||
| @ -35,7 +44,18 @@ def check_dataset_disk_space(url='https://ultralytics.com/assets/coco128.zip', s | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def request_with_credentials(url: str) -> any: | def request_with_credentials(url: str) -> any: | ||||||
|     """ Make an ajax request with cookies attached """ |     """ | ||||||
|  |     Make an AJAX request with cookies attached in a Google Colab environment. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         url (str): The URL to make the request to. | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         any: The response data from the AJAX request. | ||||||
|  | 
 | ||||||
|  |     Raises: | ||||||
|  |         OSError: If the function is not run in a Google Colab environment. | ||||||
|  |     """ | ||||||
|     if not is_colab(): |     if not is_colab(): | ||||||
|         raise OSError('request_with_credentials() must run in a Colab environment') |         raise OSError('request_with_credentials() must run in a Colab environment') | ||||||
|     from google.colab import output  # noqa |     from google.colab import output  # noqa | ||||||
| @ -95,7 +115,6 @@ def requests_with_progress(method, url, **kwargs): | |||||||
| 
 | 
 | ||||||
|     Returns: |     Returns: | ||||||
|         requests.Response: The response from the HTTP request. |         requests.Response: The response from the HTTP request. | ||||||
| 
 |  | ||||||
|     """ |     """ | ||||||
|     progress = kwargs.pop('progress', False) |     progress = kwargs.pop('progress', False) | ||||||
|     if not progress: |     if not progress: | ||||||
| @ -126,7 +145,6 @@ def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbos | |||||||
| 
 | 
 | ||||||
|     Returns: |     Returns: | ||||||
|         requests.Response: The HTTP response object. If the request is executed in a separate thread, returns None. |         requests.Response: The HTTP response object. If the request is executed in a separate thread, returns None. | ||||||
| 
 |  | ||||||
|     """ |     """ | ||||||
|     retry_codes = (408, 500)  # retry only these codes |     retry_codes = (408, 500)  # retry only these codes | ||||||
| 
 | 
 | ||||||
| @ -171,8 +189,8 @@ class Traces: | |||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         """ |         """ | ||||||
|         Initialize Traces for error tracking and reporting if tests are not currently running. |         Initialize Traces for error tracking and reporting if tests are not currently running. | ||||||
|  |         Sets the rate limit, timer, and metadata attributes, and determines whether Traces are enabled. | ||||||
|         """ |         """ | ||||||
|         from ultralytics.yolo.cfg import MODES, TASKS |  | ||||||
|         self.rate_limit = 60.0  # rate limit (seconds) |         self.rate_limit = 60.0  # rate limit (seconds) | ||||||
|         self.t = 0.0  # rate limit timer (seconds) |         self.t = 0.0  # rate limit timer (seconds) | ||||||
|         self.metadata = { |         self.metadata = { | ||||||
| @ -187,17 +205,22 @@ class Traces: | |||||||
|             not TESTS_RUNNING and \ |             not TESTS_RUNNING and \ | ||||||
|             ONLINE and \ |             ONLINE and \ | ||||||
|             (is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git') |             (is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git') | ||||||
|         self.usage = {'tasks': {k: 0 for k in TASKS}, 'modes': {k: 0 for k in MODES}} |         self._reset_usage() | ||||||
| 
 | 
 | ||||||
|     def __call__(self, cfg, all_keys=False, traces_sample_rate=1.0): |     def __call__(self, cfg, all_keys=False, traces_sample_rate=1.0): | ||||||
|         """ |         """ | ||||||
|        Sync traces data if enabled in the global settings |         Sync traces data if enabled in the global settings. | ||||||
| 
 | 
 | ||||||
|         Args: |         Args: | ||||||
|             cfg (IterableSimpleNamespace): Configuration for the task and mode. |             cfg (IterableSimpleNamespace): Configuration for the task and mode. | ||||||
|             all_keys (bool): Sync all items, not just non-default values. |             all_keys (bool): Sync all items, not just non-default values. | ||||||
|             traces_sample_rate (float): Fraction of traces captured from 0.0 to 1.0 |             traces_sample_rate (float): Fraction of traces captured from 0.0 to 1.0. | ||||||
|         """ |         """ | ||||||
|  | 
 | ||||||
|  |         # Increment usage | ||||||
|  |         self.usage['modes'][cfg.mode] = self.usage['modes'].get(cfg.mode, 0) + 1 | ||||||
|  |         self.usage['tasks'][cfg.task] = self.usage['tasks'].get(cfg.task, 0) + 1 | ||||||
|  | 
 | ||||||
|         t = time.time()  # current time |         t = time.time()  # current time | ||||||
|         if not self.enabled or random() > traces_sample_rate: |         if not self.enabled or random() > traces_sample_rate: | ||||||
|             # Traces disabled or not randomly selected, do nothing |             # Traces disabled or not randomly selected, do nothing | ||||||
| @ -207,18 +230,20 @@ class Traces: | |||||||
|             return |             return | ||||||
|         else: |         else: | ||||||
|             # Time is over rate limiter, send trace now |             # Time is over rate limiter, send trace now | ||||||
|             self.t = t  # reset rate limit timer |             trace = {'uuid': SETTINGS['uuid'], 'usage': self.usage.copy(), 'metadata': self.metadata} | ||||||
| 
 |  | ||||||
|             # Build trace |  | ||||||
|             if cfg.task in self.usage['tasks']: |  | ||||||
|                 self.usage['tasks'][cfg.task] += 1 |  | ||||||
|             if cfg.mode in self.usage['modes']: |  | ||||||
|                 self.usage['modes'][cfg.mode] += 1 |  | ||||||
|             trace = {'uuid': SETTINGS['uuid'], 'usage': self.usage, 'metadata': self.metadata} |  | ||||||
| 
 | 
 | ||||||
|             # Send a request to the HUB API to sync analytics |             # Send a request to the HUB API to sync analytics | ||||||
|             smart_request('post', f'{HUB_API_ROOT}/v1/usage/anonymous', json=trace, code=3, retry=0, verbose=False) |             smart_request('post', f'{HUB_API_ROOT}/v1/usage/anonymous', json=trace, code=3, retry=0, verbose=False) | ||||||
| 
 | 
 | ||||||
|  |             # Reset usage and rate limit timer | ||||||
|  |             self._reset_usage() | ||||||
|  |             self.t = t | ||||||
|  | 
 | ||||||
|  |     def _reset_usage(self): | ||||||
|  |         """Reset the usage dictionary by initializing keys for each task and mode with a value of 0.""" | ||||||
|  |         from ultralytics.yolo.cfg import MODES, TASKS | ||||||
|  |         self.usage = {'tasks': {k: 0 for k in TASKS}, 'modes': {k: 0 for k in MODES}} | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| # Run below code on hub/utils init ------------------------------------------------------------------------------------- | # Run below code on hub/utils init ------------------------------------------------------------------------------------- | ||||||
| traces = Traces() | traces = Traces() | ||||||
|  | |||||||
| @ -9,7 +9,8 @@ from types import SimpleNamespace | |||||||
| from typing import Dict, List, Union | from typing import Dict, List, Union | ||||||
| 
 | 
 | ||||||
| from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, ROOT, USER_CONFIG_DIR, | from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, ROOT, USER_CONFIG_DIR, | ||||||
|                                     IterableSimpleNamespace, __version__, checks, colorstr, yaml_load, yaml_print) |                                     IterableSimpleNamespace, __version__, checks, colorstr, get_settings, yaml_load, | ||||||
|  |                                     yaml_print) | ||||||
| 
 | 
 | ||||||
| # Define valid tasks and modes | # Define valid tasks and modes | ||||||
| MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark' | MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark' | ||||||
| @ -187,6 +188,51 @@ def merge_equals_args(args: List[str]) -> List[str]: | |||||||
|     return new_args |     return new_args | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def handle_yolo_hub(args: List[str]) -> None: | ||||||
|  |     """ | ||||||
|  |     Handle Ultralytics HUB command-line interface (CLI) commands. | ||||||
|  | 
 | ||||||
|  |     This function processes Ultralytics HUB CLI commands such as login and logout. | ||||||
|  |     It should be called when executing a script with arguments related to HUB authentication. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         args (List[str]): A list of command line arguments | ||||||
|  | 
 | ||||||
|  |     Example: | ||||||
|  |         python my_script.py hub login your_api_key | ||||||
|  |     """ | ||||||
|  |     from ultralytics import hub | ||||||
|  | 
 | ||||||
|  |     if args[0] == 'login': | ||||||
|  |         key = args[1] if len(args) > 1 else '' | ||||||
|  |         # Log in to Ultralytics HUB using the provided API key | ||||||
|  |         hub.login(key) | ||||||
|  |     elif args[0] == 'logout': | ||||||
|  |         # Log out from Ultralytics HUB | ||||||
|  |         hub.logout() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def handle_yolo_settings(args: List[str]) -> None: | ||||||
|  |     """ | ||||||
|  |     Handle YOLO settings command-line interface (CLI) commands. | ||||||
|  | 
 | ||||||
|  |     This function processes YOLO settings CLI commands such as reset. | ||||||
|  |     It should be called when executing a script with arguments related to YOLO settings management. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         args (List[str]): A list of command line arguments for YOLO settings management. | ||||||
|  | 
 | ||||||
|  |     Example: | ||||||
|  |         python my_script.py yolo settings reset | ||||||
|  |     """ | ||||||
|  |     path = USER_CONFIG_DIR / 'settings.yaml'  # get SETTINGS YAML file path | ||||||
|  |     if any(args) and args[0] == 'reset': | ||||||
|  |         path.unlink()  # delete the settings file | ||||||
|  |         get_settings()  # create new settings | ||||||
|  |         LOGGER.info('Settings reset successfully')  # inform the user that settings have been reset | ||||||
|  |     yaml_print(path)  # print the current settings | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def entrypoint(debug=''): | def entrypoint(debug=''): | ||||||
|     """ |     """ | ||||||
|     This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed |     This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed | ||||||
| @ -211,8 +257,10 @@ def entrypoint(debug=''): | |||||||
|         'help': lambda: LOGGER.info(CLI_HELP_MSG), |         'help': lambda: LOGGER.info(CLI_HELP_MSG), | ||||||
|         'checks': checks.check_yolo, |         'checks': checks.check_yolo, | ||||||
|         'version': lambda: LOGGER.info(__version__), |         'version': lambda: LOGGER.info(__version__), | ||||||
|         'settings': lambda: yaml_print(USER_CONFIG_DIR / 'settings.yaml'), |         'settings': lambda: handle_yolo_settings(args[1:]), | ||||||
|         'cfg': lambda: yaml_print(DEFAULT_CFG_PATH), |         'cfg': lambda: yaml_print(DEFAULT_CFG_PATH), | ||||||
|  |         'hub': lambda: handle_yolo_hub(args[1:]), | ||||||
|  |         'login': lambda: handle_yolo_hub(args), | ||||||
|         'copy-cfg': copy_default_cfg} |         'copy-cfg': copy_default_cfg} | ||||||
|     full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special} |     full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special} | ||||||
| 
 | 
 | ||||||
| @ -255,8 +303,8 @@ def entrypoint(debug=''): | |||||||
|             overrides['task'] = a |             overrides['task'] = a | ||||||
|         elif a in MODES: |         elif a in MODES: | ||||||
|             overrides['mode'] = a |             overrides['mode'] = a | ||||||
|         elif a in special: |         elif a.lower() in special: | ||||||
|             special[a]() |             special[a.lower()]() | ||||||
|             return |             return | ||||||
|         elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool): |         elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool): | ||||||
|             overrides[a] = True  # auto-True for default bool args, i.e. 'yolo show' sets show=True |             overrides[a] = True  # auto-True for default bool args, i.e. 'yolo show' sets show=True | ||||||
|  | |||||||
| @ -68,12 +68,14 @@ class YOLO: | |||||||
|         list(ultralytics.yolo.engine.results.Results): The prediction results. |         list(ultralytics.yolo.engine.results.Results): The prediction results. | ||||||
|     """ |     """ | ||||||
| 
 | 
 | ||||||
|     def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None, session=None) -> None: |     def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None) -> None: | ||||||
|         """ |         """ | ||||||
|         Initializes the YOLO model. |         Initializes the YOLO model. | ||||||
| 
 | 
 | ||||||
|         Args: |         Args: | ||||||
|             model (str, Path): model to load or create |             model (Union[str, Path], optional): Path or name of the model to load or create. Defaults to 'yolov8n.pt'. | ||||||
|  |             task (Any, optional): Task type for the YOLO model. Defaults to None. | ||||||
|  | 
 | ||||||
|         """ |         """ | ||||||
|         self._reset_callbacks() |         self._reset_callbacks() | ||||||
|         self.predictor = None  # reuse predictor |         self.predictor = None  # reuse predictor | ||||||
| @ -85,10 +87,16 @@ class YOLO: | |||||||
|         self.ckpt_path = None |         self.ckpt_path = None | ||||||
|         self.overrides = {}  # overrides for trainer object |         self.overrides = {}  # overrides for trainer object | ||||||
|         self.metrics = None  # validation/training metrics |         self.metrics = None  # validation/training metrics | ||||||
|         self.session = session  # HUB session |         self.session = None  # HUB session | ||||||
|  |         model = str(model).strip()  # strip spaces | ||||||
|  | 
 | ||||||
|  |         # Check if Ultralytics HUB model from https://hub.ultralytics.com | ||||||
|  |         if model.startswith('https://hub.ultralytics.com/models/'): | ||||||
|  |             from ultralytics.hub import HUBTrainingSession | ||||||
|  |             self.session = HUBTrainingSession(model) | ||||||
|  |             model = self.session.model_file | ||||||
| 
 | 
 | ||||||
|         # Load or create new YOLO model |         # Load or create new YOLO model | ||||||
|         model = str(model).strip()  # strip spaces |  | ||||||
|         suffix = Path(model).suffix |         suffix = Path(model).suffix | ||||||
|         if not suffix and Path(model).stem in GITHUB_ASSET_STEMS: |         if not suffix and Path(model).stem in GITHUB_ASSET_STEMS: | ||||||
|             model, suffix = Path(model).with_suffix('.pt'), '.pt'  # add suffix, i.e. yolov8n -> yolov8n.pt |             model, suffix = Path(model).with_suffix('.pt'), '.pt'  # add suffix, i.e. yolov8n -> yolov8n.pt | ||||||
| @ -280,6 +288,7 @@ class YOLO: | |||||||
|         from ultralytics.yolo.utils.benchmarks import benchmark |         from ultralytics.yolo.utils.benchmarks import benchmark | ||||||
|         overrides = self.model.args.copy() |         overrides = self.model.args.copy() | ||||||
|         overrides.update(kwargs) |         overrides.update(kwargs) | ||||||
|  |         overrides['mode'] = 'benchmark' | ||||||
|         overrides = {**DEFAULT_CFG_DICT, **overrides}  # fill in missing overrides keys with defaults |         overrides = {**DEFAULT_CFG_DICT, **overrides}  # fill in missing overrides keys with defaults | ||||||
|         return benchmark(model=self, imgsz=overrides['imgsz'], half=overrides['half'], device=overrides['device']) |         return benchmark(model=self, imgsz=overrides['imgsz'], half=overrides['half'], device=overrides['device']) | ||||||
| 
 | 
 | ||||||
| @ -293,6 +302,7 @@ class YOLO: | |||||||
|         self._check_is_pytorch_model() |         self._check_is_pytorch_model() | ||||||
|         overrides = self.overrides.copy() |         overrides = self.overrides.copy() | ||||||
|         overrides.update(kwargs) |         overrides.update(kwargs) | ||||||
|  |         overrides['mode'] = 'export' | ||||||
|         args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) |         args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) | ||||||
|         args.task = self.task |         args.task = self.task | ||||||
|         if args.imgsz == DEFAULT_CFG.imgsz: |         if args.imgsz == DEFAULT_CFG.imgsz: | ||||||
| @ -309,6 +319,11 @@ class YOLO: | |||||||
|             **kwargs (Any): Any number of arguments representing the training configuration. |             **kwargs (Any): Any number of arguments representing the training configuration. | ||||||
|         """ |         """ | ||||||
|         self._check_is_pytorch_model() |         self._check_is_pytorch_model() | ||||||
|  |         if self.session:  # Ultralytics HUB session | ||||||
|  |             if any(kwargs): | ||||||
|  |                 LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.') | ||||||
|  |             kwargs = self.session.train_args | ||||||
|  |             self.session.check_disk_space() | ||||||
|         check_pip_update_available() |         check_pip_update_available() | ||||||
|         overrides = self.overrides.copy() |         overrides = self.overrides.copy() | ||||||
|         overrides.update(kwargs) |         overrides.update(kwargs) | ||||||
|  | |||||||
| @ -277,6 +277,8 @@ class Masks(SimpleClass): | |||||||
|         self.masks = masks  # N, h, w |         self.masks = masks  # N, h, w | ||||||
|         self.orig_shape = orig_shape |         self.orig_shape = orig_shape | ||||||
| 
 | 
 | ||||||
|  |     @property | ||||||
|  |     @lru_cache(maxsize=1) | ||||||
|     def segments(self): |     def segments(self): | ||||||
|         # Segments-deprecated (normalized) |         # Segments-deprecated (normalized) | ||||||
|         LOGGER.warning("WARNING ⚠️ 'Masks.segments' is deprecated. Use 'Masks.xyn' for segments (normalized) and " |         LOGGER.warning("WARNING ⚠️ 'Masks.segments' is deprecated. Use 'Masks.xyn' for segments (normalized) and " | ||||||
|  | |||||||
| @ -321,10 +321,13 @@ def is_online() -> bool: | |||||||
|         bool: True if connection is successful, False otherwise. |         bool: True if connection is successful, False otherwise. | ||||||
|     """ |     """ | ||||||
|     import socket |     import socket | ||||||
|     with contextlib.suppress(Exception): | 
 | ||||||
|         host = socket.gethostbyname('www.github.com') |     for server in '1.1.1.1', '8.8.8.8', '223.5.5.5':  # Cloudflare, Google, AliDNS: | ||||||
|         socket.create_connection((host, 80), timeout=2) |         try: | ||||||
|         return True |             socket.create_connection((server, 53), timeout=2)  # connect to (server, port=53) | ||||||
|  |             return True | ||||||
|  |         except (socket.timeout, socket.gaierror, OSError): | ||||||
|  |             continue | ||||||
|     return False |     return False | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -586,7 +589,7 @@ def set_sentry(): | |||||||
|             logging.getLogger(logger).setLevel(logging.CRITICAL) |             logging.getLogger(logger).setLevel(logging.CRITICAL) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.2'): | def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.3'): | ||||||
|     """ |     """ | ||||||
|     Loads a global Ultralytics settings YAML file or creates one with default values if it does not exist. |     Loads a global Ultralytics settings YAML file or creates one with default values if it does not exist. | ||||||
| 
 | 
 | ||||||
| @ -609,8 +612,9 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.2'): | |||||||
|         'datasets_dir': str(datasets_root / 'datasets'),  # default datasets directory. |         'datasets_dir': str(datasets_root / 'datasets'),  # default datasets directory. | ||||||
|         'weights_dir': str(root / 'weights'),  # default weights directory. |         'weights_dir': str(root / 'weights'),  # default weights directory. | ||||||
|         'runs_dir': str(root / 'runs'),  # default runs directory. |         'runs_dir': str(root / 'runs'),  # default runs directory. | ||||||
|         'sync': True,  # sync analytics to help with YOLO development |  | ||||||
|         'uuid': hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(),  # anonymized uuid hash |         'uuid': hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(),  # anonymized uuid hash | ||||||
|  |         'sync': True,  # sync analytics to help with YOLO development | ||||||
|  |         'api_key': '',  # Ultralytics HUB API key (https://hub.ultralytics.com/) | ||||||
|         'settings_version': version}  # Ultralytics settings version |         'settings_version': version}  # Ultralytics settings version | ||||||
| 
 | 
 | ||||||
|     with torch_distributed_zero_first(RANK): |     with torch_distributed_zero_first(RANK): | ||||||
|  | |||||||
| @ -25,7 +25,7 @@ def on_pretrain_routine_end(trainer): | |||||||
|         mlflow_location = os.environ['MLFLOW_TRACKING_URI']  # "http://192.168.xxx.xxx:5000" |         mlflow_location = os.environ['MLFLOW_TRACKING_URI']  # "http://192.168.xxx.xxx:5000" | ||||||
|         mlflow.set_tracking_uri(mlflow_location) |         mlflow.set_tracking_uri(mlflow_location) | ||||||
| 
 | 
 | ||||||
|         experiment_name = trainer.args.project or 'YOLOv8' |         experiment_name = trainer.args.project or '/Shared/YOLOv8' | ||||||
|         experiment = mlflow.get_experiment_by_name(experiment_name) |         experiment = mlflow.get_experiment_by_name(experiment_name) | ||||||
|         if experiment is None: |         if experiment is None: | ||||||
|             mlflow.create_experiment(experiment_name) |             mlflow.create_experiment(experiment_name) | ||||||
| @ -33,16 +33,15 @@ def on_pretrain_routine_end(trainer): | |||||||
| 
 | 
 | ||||||
|         prefix = colorstr('MLFlow: ') |         prefix = colorstr('MLFlow: ') | ||||||
|         try: |         try: | ||||||
|             run, active_run = mlflow, mlflow.start_run() if mlflow else None |             run, active_run = mlflow, mlflow.active_run() | ||||||
|             if active_run is not None: |             if not active_run: | ||||||
|                 run_id = active_run.info.run_id |                 active_run = mlflow.start_run(experiment_id=experiment.experiment_id) | ||||||
|                 LOGGER.info(f'{prefix}Using run_id({run_id}) at {mlflow_location}') |             run_id = active_run.info.run_id | ||||||
|  |             LOGGER.info(f'{prefix}Using run_id({run_id}) at {mlflow_location}') | ||||||
|  |             run.log_params(vars(trainer.model.args)) | ||||||
|         except Exception as err: |         except Exception as err: | ||||||
|             LOGGER.error(f'{prefix}Failing init - {repr(err)}') |             LOGGER.error(f'{prefix}Failing init - {repr(err)}') | ||||||
|             LOGGER.warning(f'{prefix}Continuing without Mlflow') |             LOGGER.warning(f'{prefix}Continuing without Mlflow') | ||||||
|             run = None |  | ||||||
| 
 |  | ||||||
|         run.log_params(vars(trainer.model.args)) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def on_fit_epoch_end(trainer): | def on_fit_epoch_end(trainer): | ||||||
|  | |||||||
| @ -142,7 +142,7 @@ def check_pip_update_available(): | |||||||
|         bool: True if an update is available, False otherwise. |         bool: True if an update is available, False otherwise. | ||||||
|     """ |     """ | ||||||
|     if ONLINE and is_pip_package(): |     if ONLINE and is_pip_package(): | ||||||
|         with contextlib.suppress(ConnectionError): |         with contextlib.suppress(Exception): | ||||||
|             from ultralytics import __version__ |             from ultralytics import __version__ | ||||||
|             latest = check_latest_pypi_version() |             latest = check_latest_pypi_version() | ||||||
|             if pkg.parse_version(__version__) < pkg.parse_version(latest):  # update is available |             if pkg.parse_version(__version__) < pkg.parse_version(latest):  # update is available | ||||||
|  | |||||||
| @ -12,7 +12,7 @@ import requests | |||||||
| import torch | import torch | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
| 
 | 
 | ||||||
| from ultralytics.yolo.utils import LOGGER, checks, is_online | from ultralytics.yolo.utils import LOGGER, checks, emojis, is_online | ||||||
| 
 | 
 | ||||||
| GITHUB_ASSET_NAMES = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] + \ | GITHUB_ASSET_NAMES = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] + \ | ||||||
|                      [f'yolov5{size}u.pt' for size in 'nsmlx'] + \ |                      [f'yolov5{size}u.pt' for size in 'nsmlx'] + \ | ||||||
| @ -113,9 +113,9 @@ def safe_download(url, | |||||||
|                     f.unlink()  # remove partial downloads |                     f.unlink()  # remove partial downloads | ||||||
|             except Exception as e: |             except Exception as e: | ||||||
|                 if i == 0 and not is_online(): |                 if i == 0 and not is_online(): | ||||||
|                     raise ConnectionError(f'❌  Download failure for {url}. Environment is not online.') from e |                     raise ConnectionError(emojis(f'❌  Download failure for {url}. Environment is not online.')) from e | ||||||
|                 elif i >= retry: |                 elif i >= retry: | ||||||
|                     raise ConnectionError(f'❌  Download failure for {url}. Retry limit reached.') from e |                     raise ConnectionError(emojis(f'❌  Download failure for {url}. Retry limit reached.')) from e | ||||||
|                 LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...') |                 LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...') | ||||||
| 
 | 
 | ||||||
|     if unzip and f.exists() and f.suffix in ('.zip', '.tar', '.gz'): |     if unzip and f.exists() and f.suffix in ('.zip', '.tar', '.gz'): | ||||||
|  | |||||||
| @ -114,7 +114,7 @@ class Annotator: | |||||||
|             self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255 |             self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255 | ||||||
|         if im_gpu.device != masks.device: |         if im_gpu.device != masks.device: | ||||||
|             im_gpu = im_gpu.to(masks.device) |             im_gpu = im_gpu.to(masks.device) | ||||||
|         colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0 |         colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0  # shape(n,3) | ||||||
|         colors = colors[:, None, None]  # shape(n,1,1,3) |         colors = colors[:, None, None]  # shape(n,1,1,3) | ||||||
|         masks = masks.unsqueeze(3)  # shape(n,h,w,1) |         masks = masks.unsqueeze(3)  # shape(n,h,w,1) | ||||||
|         masks_color = masks * (colors * alpha)  # shape(n,h,w,3) |         masks_color = masks * (colors * alpha)  # shape(n,h,w,3) | ||||||
|  | |||||||
| @ -78,7 +78,7 @@ class SegmentationPredictor(DetectionPredictor): | |||||||
|         for j, d in enumerate(reversed(det)): |         for j, d in enumerate(reversed(det)): | ||||||
|             c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item()) |             c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item()) | ||||||
|             if self.args.save_txt:  # Write to file |             if self.args.save_txt:  # Write to file | ||||||
|                 seg = mask.segments[len(det) - j - 1].copy().reshape(-1)  # reversed mask.segments, (n,2) to (n*2) |                 seg = mask.xyn[len(det) - j - 1].copy().reshape(-1)  # reversed mask.xyn, (n,2) to (n*2) | ||||||
|                 line = (c, *seg) + (conf, ) * self.args.save_conf + (() if id is None else (id, )) |                 line = (c, *seg) + (conf, ) * self.args.save_conf + (() if id is None else (id, )) | ||||||
|                 with open(f'{self.txt_path}.txt', 'a') as f: |                 with open(f'{self.txt_path}.txt', 'a') as f: | ||||||
|                     f.write(('%g ' * len(line)).rstrip() % line + '\n') |                     f.write(('%g ' * len(line)).rstrip() % line + '\n') | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Glenn Jocher
						Glenn Jocher