mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
8.0.60
new HUB training syntax (#1753)
Co-authored-by: Rafael Pierre <97888102+rafaelvp-db@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: Semih Demirel <85176438+semihhdemirel@users.noreply.github.com>
This commit is contained in:
parent
e7876e1ba9
commit
84948651cd
28
.github/workflows/ci.yaml
vendored
28
.github/workflows/ci.yaml
vendored
@ -7,7 +7,7 @@ on:
|
|||||||
push:
|
push:
|
||||||
branches: [main]
|
branches: [main]
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [main]
|
branches: [main, updates]
|
||||||
schedule:
|
schedule:
|
||||||
- cron: '0 0 * * *' # runs at 00:00 UTC every day
|
- cron: '0 0 * * *' # runs at 00:00 UTC every day
|
||||||
|
|
||||||
@ -43,16 +43,36 @@ jobs:
|
|||||||
python --version
|
python --version
|
||||||
pip --version
|
pip --version
|
||||||
pip list
|
pip list
|
||||||
- name: Test HUB training
|
- name: Test HUB training (Python Usage 1)
|
||||||
shell: python
|
shell: python
|
||||||
env:
|
env:
|
||||||
APIKEY: ${{ secrets.ULTRALYTICS_HUB_APIKEY }}
|
APIKEY: ${{ secrets.ULTRALYTICS_HUB_APIKEY }}
|
||||||
run: |
|
run: |
|
||||||
import os
|
import os
|
||||||
from ultralytics import hub
|
from pathlib import Path
|
||||||
|
from ultralytics import YOLO, hub
|
||||||
|
from ultralytics.yolo.utils import USER_CONFIG_DIR
|
||||||
|
Path(USER_CONFIG_DIR / 'settings.yaml').unlink()
|
||||||
key = os.environ['APIKEY']
|
key = os.environ['APIKEY']
|
||||||
hub.reset_model(key)
|
hub.reset_model(key)
|
||||||
hub.start(key)
|
model = YOLO('https://hub.ultralytics.com/models/' + key)
|
||||||
|
model.train()
|
||||||
|
- name: Test HUB training (Python Usage 2)
|
||||||
|
shell: python
|
||||||
|
env:
|
||||||
|
APIKEY: ${{ secrets.ULTRALYTICS_HUB_APIKEY }}
|
||||||
|
run: |
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from ultralytics import YOLO, hub
|
||||||
|
from ultralytics.yolo.utils import USER_CONFIG_DIR
|
||||||
|
Path(USER_CONFIG_DIR / 'settings.yaml').unlink()
|
||||||
|
key = os.environ['APIKEY']
|
||||||
|
hub.reset_model(key)
|
||||||
|
key, model_id = key.split('_')
|
||||||
|
hub.login(key)
|
||||||
|
model = YOLO('https://hub.ultralytics.com/models/' + model_id)
|
||||||
|
model.train()
|
||||||
|
|
||||||
Benchmarks:
|
Benchmarks:
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
|
@ -26,6 +26,7 @@ WORKDIR /usr/src/ultralytics
|
|||||||
# Copy contents
|
# Copy contents
|
||||||
# COPY . /usr/src/app (issues as not a .git directory)
|
# COPY . /usr/src/app (issues as not a .git directory)
|
||||||
RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics
|
RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics
|
||||||
|
ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt /usr/src/ultralytics/
|
||||||
|
|
||||||
# Install pip packages
|
# Install pip packages
|
||||||
RUN python3 -m pip install --upgrade pip wheel
|
RUN python3 -m pip install --upgrade pip wheel
|
||||||
|
@ -22,6 +22,7 @@ WORKDIR /usr/src/ultralytics
|
|||||||
# Copy contents
|
# Copy contents
|
||||||
# COPY . /usr/src/app (issues as not a .git directory)
|
# COPY . /usr/src/app (issues as not a .git directory)
|
||||||
RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics
|
RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics
|
||||||
|
ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt /usr/src/ultralytics/
|
||||||
|
|
||||||
# Install pip packages
|
# Install pip packages
|
||||||
RUN python3 -m pip install --upgrade pip wheel
|
RUN python3 -m pip install --upgrade pip wheel
|
||||||
|
@ -22,6 +22,7 @@ WORKDIR /usr/src/ultralytics
|
|||||||
# Copy contents
|
# Copy contents
|
||||||
# COPY . /usr/src/app (issues as not a .git directory)
|
# COPY . /usr/src/app (issues as not a .git directory)
|
||||||
RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics
|
RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics
|
||||||
|
ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt /usr/src/ultralytics/
|
||||||
|
|
||||||
# Install pip packages
|
# Install pip packages
|
||||||
RUN python3 -m pip install --upgrade pip wheel
|
RUN python3 -m pip install --upgrade pip wheel
|
||||||
|
@ -17,7 +17,7 @@ passing `stream=True` in the predictor's call method.
|
|||||||
probs = result.probs # Class probabilities for classification outputs
|
probs = result.probs # Class probabilities for classification outputs
|
||||||
```
|
```
|
||||||
|
|
||||||
=== "Return a list with `Stream=True`"
|
=== "Return a generator with `Stream=True`"
|
||||||
```python
|
```python
|
||||||
inputs = [img, img] # list of numpy arrays
|
inputs = [img, img] # list of numpy arrays
|
||||||
results = model(inputs, stream=True) # generator of Results objects
|
results = model(inputs, stream=True) # generator of Results objects
|
||||||
@ -54,6 +54,40 @@ whether each source can be used in streaming mode with `stream=True` ✅ and an
|
|||||||
| YouTube ✅ | `'https://youtu.be/Zgi9g1ksQHc'` | `str` | |
|
| YouTube ✅ | `'https://youtu.be/Zgi9g1ksQHc'` | `str` | |
|
||||||
| stream ✅ | `'rtsp://example.com/media.mp4'` | `str` | RTSP, RTMP, HTTP |
|
| stream ✅ | `'rtsp://example.com/media.mp4'` | `str` | RTSP, RTMP, HTTP |
|
||||||
|
|
||||||
|
|
||||||
|
## Arguments
|
||||||
|
`model.predict` accepts multiple arguments that control the predction operation. These arguments can be passed directly to `model.predict`:
|
||||||
|
!!! example
|
||||||
|
```
|
||||||
|
model.predict(source, save=True, imgsz=320, conf=0.5)
|
||||||
|
```
|
||||||
|
|
||||||
|
All supported arguments:
|
||||||
|
|
||||||
|
| Key | Value | Description |
|
||||||
|
|------------------|------------------------|----------------------------------------------------------|
|
||||||
|
| `source` | `'ultralytics/assets'` | source directory for images or videos |
|
||||||
|
| `conf` | `0.25` | object confidence threshold for detection |
|
||||||
|
| `iou` | `0.7` | intersection over union (IoU) threshold for NMS |
|
||||||
|
| `half` | `False` | use half precision (FP16) |
|
||||||
|
| `device` | `None` | device to run on, i.e. cuda device=0/1/2/3 or device=cpu |
|
||||||
|
| `show` | `False` | show results if possible |
|
||||||
|
| `save` | `False` | save images with results |
|
||||||
|
| `save_txt` | `False` | save results as .txt file |
|
||||||
|
| `save_conf` | `False` | save results with confidence scores |
|
||||||
|
| `save_crop` | `False` | save cropped images with results |
|
||||||
|
| `hide_labels` | `False` | hide labels |
|
||||||
|
| `hide_conf` | `False` | hide confidence scores |
|
||||||
|
| `max_det` | `300` | maximum number of detections per image |
|
||||||
|
| `vid_stride` | `False` | video frame-rate stride |
|
||||||
|
| `line_thickness` | `3` | bounding box thickness (pixels) |
|
||||||
|
| `visualize` | `False` | visualize model features |
|
||||||
|
| `augment` | `False` | apply image augmentation to prediction sources |
|
||||||
|
| `agnostic_nms` | `False` | class-agnostic NMS |
|
||||||
|
| `retina_masks` | `False` | use high-resolution segmentation masks |
|
||||||
|
| `classes` | `None` | filter results by class, i.e. class=0, or class=[0,2,3] |
|
||||||
|
| `boxes` | `True` | Show boxes in segmentation predictions |
|
||||||
|
|
||||||
## Image and Video Formats
|
## Image and Video Formats
|
||||||
|
|
||||||
YOLOv8 supports various image and video formats, as specified
|
YOLOv8 supports various image and video formats, as specified
|
||||||
|
@ -96,7 +96,6 @@ names:
|
|||||||
77: teddy bear
|
77: teddy bear
|
||||||
78: hair drier
|
78: hair drier
|
||||||
79: toothbrush
|
79: toothbrush
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ Just simply clone and run
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
python main.py
|
python main.py --model yolov8n.onnx --img image.jpg
|
||||||
```
|
```
|
||||||
|
|
||||||
If you start from scratch:
|
If you start from scratch:
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
import argparse
|
||||||
|
|
||||||
import cv2.dnn
|
import cv2.dnn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -16,9 +18,9 @@ def draw_bounding_box(img, class_id, confidence, x, y, x_plus_w, y_plus_h):
|
|||||||
cv2.putText(img, label, (x - 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
|
cv2.putText(img, label, (x - 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main(onnx_model, input_image):
|
||||||
model: cv2.dnn.Net = cv2.dnn.readNetFromONNX('yolov8n.onnx')
|
model: cv2.dnn.Net = cv2.dnn.readNetFromONNX(onnx_model)
|
||||||
original_image: np.ndarray = cv2.imread(str(ROOT / 'assets/bus.jpg'))
|
original_image: np.ndarray = cv2.imread(input_image)
|
||||||
[height, width, _] = original_image.shape
|
[height, width, _] = original_image.shape
|
||||||
length = max((height, width))
|
length = max((height, width))
|
||||||
image = np.zeros((length, length, 3), np.uint8)
|
image = np.zeros((length, length, 3), np.uint8)
|
||||||
@ -71,4 +73,8 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--model', default='yolov8n.onnx', help='Input your onnx model.')
|
||||||
|
parser.add_argument('--img', default=str(ROOT / 'assets/bus.jpg'), help='Path to input image.')
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args.model, args.img)
|
||||||
|
@ -46,7 +46,7 @@ theme:
|
|||||||
- content.tabs.link # all code tabs change simultaneously
|
- content.tabs.link # all code tabs change simultaneously
|
||||||
|
|
||||||
# Customization
|
# Customization
|
||||||
copyright: Ultralytics 2023. All rights reserved.
|
copyright: <a href="https://ultralytics.com" target="_blank">Ultralytics 2023.</a> All rights reserved.
|
||||||
extra:
|
extra:
|
||||||
# version:
|
# version:
|
||||||
# provider: mike # version drop-down menu
|
# provider: mike # version drop-down menu
|
||||||
@ -167,7 +167,7 @@ nav:
|
|||||||
- Hyperparameter evolution: yolov5/hyp_evolution.md
|
- Hyperparameter evolution: yolov5/hyp_evolution.md
|
||||||
- Transfer learning with frozen layers: yolov5/transfer_learn_frozen.md
|
- Transfer learning with frozen layers: yolov5/transfer_learn_frozen.md
|
||||||
- Architecture Summary: yolov5/architecture.md
|
- Architecture Summary: yolov5/architecture.md
|
||||||
- Roboflow for Datasets, Labeling, and Active Learning: yolov5/roboflow.md
|
- Roboflow Datasets: yolov5/roboflow.md
|
||||||
- Neural Magic's DeepSparse: yolov5/neural_magic.md
|
- Neural Magic's DeepSparse: yolov5/neural_magic.md
|
||||||
- Comet Logging: yolov5/comet.md
|
- Comet Logging: yolov5/comet.md
|
||||||
- Clearml Logging: yolov5/clearml.md
|
- Clearml Logging: yolov5/clearml.md
|
||||||
|
2
setup.py
2
setup.py
@ -58,7 +58,7 @@ setup(
|
|||||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||||
'Topic :: Scientific/Engineering :: Image Recognition',
|
'Topic :: Scientific/Engineering :: Image Recognition',
|
||||||
'Operating System :: POSIX :: Linux',
|
'Operating System :: POSIX :: Linux',
|
||||||
'Operating System :: macOS',
|
'Operating System :: MacOS',
|
||||||
'Operating System :: Microsoft :: Windows', ],
|
'Operating System :: Microsoft :: Windows', ],
|
||||||
keywords='machine-learning, deep-learning, vision, ML, DL, AI, YOLO, YOLOv3, YOLOv5, YOLOv8, HUB, Ultralytics',
|
keywords='machine-learning, deep-learning, vision, ML, DL, AI, YOLO, YOLOv3, YOLOv5, YOLOv8, HUB, Ultralytics',
|
||||||
entry_points={
|
entry_points={
|
||||||
|
@ -56,11 +56,11 @@ def test_predict_detect():
|
|||||||
|
|
||||||
|
|
||||||
def test_predict_segment():
|
def test_predict_segment():
|
||||||
run(f"yolo predict model={MODEL}-seg.pt source={ROOT / 'assets'} imgsz=32 save")
|
run(f"yolo predict model={MODEL}-seg.pt source={ROOT / 'assets'} imgsz=32 save save_txt")
|
||||||
|
|
||||||
|
|
||||||
def test_predict_classify():
|
def test_predict_classify():
|
||||||
run(f"yolo predict model={MODEL}-cls.pt source={ROOT / 'assets'} imgsz=32 save")
|
run(f"yolo predict model={MODEL}-cls.pt source={ROOT / 'assets'} imgsz=32 save save_txt")
|
||||||
|
|
||||||
|
|
||||||
# Export checks --------------------------------------------------------------------------------------------------------
|
# Export checks --------------------------------------------------------------------------------------------------------
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
|
|
||||||
__version__ = '8.0.59'
|
__version__ = '8.0.60'
|
||||||
|
|
||||||
|
from ultralytics.hub import start
|
||||||
from ultralytics.yolo.engine.model import YOLO
|
from ultralytics.yolo.engine.model import YOLO
|
||||||
from ultralytics.yolo.utils.checks import check_yolo as checks
|
from ultralytics.yolo.utils.checks import check_yolo as checks
|
||||||
|
|
||||||
__all__ = '__version__', 'YOLO', 'checks' # allow simpler import
|
__all__ = '__version__', 'YOLO', 'checks', 'start' # allow simpler import
|
||||||
|
@ -2,47 +2,51 @@
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from ultralytics.hub.auth import Auth
|
|
||||||
from ultralytics.hub.session import HUBTrainingSession
|
|
||||||
from ultralytics.hub.utils import PREFIX, split_key
|
from ultralytics.hub.utils import PREFIX, split_key
|
||||||
from ultralytics.yolo.engine.model import YOLO
|
from ultralytics.yolo.utils import LOGGER
|
||||||
from ultralytics.yolo.utils import LOGGER, emojis
|
|
||||||
|
|
||||||
|
def login(api_key=''):
|
||||||
|
"""
|
||||||
|
Log in to the Ultralytics HUB API using the provided API key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
|
||||||
|
|
||||||
|
Example:
|
||||||
|
from ultralytics import hub
|
||||||
|
hub.login('your_api_key')
|
||||||
|
"""
|
||||||
|
from ultralytics.hub.auth import Auth
|
||||||
|
Auth(api_key)
|
||||||
|
|
||||||
|
|
||||||
|
def logout():
|
||||||
|
"""
|
||||||
|
Logout Ultralytics HUB
|
||||||
|
|
||||||
|
Example:
|
||||||
|
from ultralytics import hub
|
||||||
|
hub.logout()
|
||||||
|
"""
|
||||||
|
LOGGER.warning('WARNING ⚠️ This method is not yet implemented.')
|
||||||
|
|
||||||
|
|
||||||
def start(key=''):
|
def start(key=''):
|
||||||
"""
|
"""
|
||||||
Start training models with Ultralytics HUB. Usage: from ultralytics.hub import start; start('API_KEY')
|
Start training models with Ultralytics HUB (DEPRECATED).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key (str, optional): A string containing either the API key and model ID combination (apikey_modelid),
|
||||||
|
or the full model URL (https://hub.ultralytics.com/models/apikey_modelid).
|
||||||
"""
|
"""
|
||||||
auth = Auth(key)
|
LOGGER.warning(f"""
|
||||||
model_id = split_key(key)[1] if auth.get_state() else request_api_key(auth)
|
WARNING ⚠️ ultralytics.start() is deprecated in 8.0.60. Updated usage to train your Ultralytics HUB model is below:
|
||||||
if not model_id:
|
|
||||||
raise ConnectionError(emojis('Connecting with global API key is not currently supported. ❌'))
|
|
||||||
|
|
||||||
session = HUBTrainingSession(model_id=model_id, auth=auth)
|
from ultralytics import YOLO
|
||||||
session.check_disk_space()
|
|
||||||
|
|
||||||
model = YOLO(model=session.model_file, session=session)
|
model = YOLO('https://hub.ultralytics.com/models/{key}')
|
||||||
model.train(**session.train_args)
|
model.train()""")
|
||||||
|
|
||||||
|
|
||||||
def request_api_key(auth, max_attempts=3):
|
|
||||||
"""
|
|
||||||
Prompt the user to input their API key. Returns the model ID.
|
|
||||||
"""
|
|
||||||
import getpass
|
|
||||||
for attempts in range(max_attempts):
|
|
||||||
LOGGER.info(f'{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}')
|
|
||||||
input_key = getpass.getpass(
|
|
||||||
'Enter your Ultralytics API Key from https://hub.ultralytics.com/settings?tab=api+keys:\n')
|
|
||||||
auth.api_key, model_id = split_key(input_key)
|
|
||||||
|
|
||||||
if auth.authenticate():
|
|
||||||
LOGGER.info(f'{PREFIX}Authenticated ✅')
|
|
||||||
return model_id
|
|
||||||
|
|
||||||
LOGGER.warning(f'{PREFIX}Invalid API key ⚠️\n')
|
|
||||||
|
|
||||||
raise ConnectionError(emojis(f'{PREFIX}Failed to authenticate ❌'))
|
|
||||||
|
|
||||||
|
|
||||||
def reset_model(key=''):
|
def reset_model(key=''):
|
||||||
|
@ -2,27 +2,74 @@
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from ultralytics.hub.utils import HUB_API_ROOT, request_with_credentials
|
from ultralytics.hub.utils import HUB_API_ROOT, PREFIX, request_with_credentials
|
||||||
from ultralytics.yolo.utils import is_colab
|
from ultralytics.yolo.utils import LOGGER, SETTINGS, emojis, is_colab, set_settings
|
||||||
|
|
||||||
API_KEY_PATH = 'https://hub.ultralytics.com/settings?tab=api+keys'
|
API_KEY_URL = 'https://hub.ultralytics.com/settings?tab=api+keys'
|
||||||
|
|
||||||
|
|
||||||
class Auth:
|
class Auth:
|
||||||
id_token = api_key = model_key = False
|
id_token = api_key = model_key = False
|
||||||
|
|
||||||
def __init__(self, api_key=None):
|
def __init__(self, api_key=''):
|
||||||
self.api_key = self._clean_api_key(api_key)
|
"""
|
||||||
self.authenticate() if self.api_key else self.auth_with_cookies()
|
Initialize the Auth class with an optional API key.
|
||||||
|
|
||||||
@staticmethod
|
Args:
|
||||||
def _clean_api_key(key: str) -> str:
|
api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
|
||||||
"""Strip model from key if present"""
|
"""
|
||||||
separator = '_'
|
# Split the input API key in case it contains a combined key_model and keep only the API key part
|
||||||
return key.split(separator)[0] if separator in key else key
|
api_key = api_key.split('_')[0]
|
||||||
|
|
||||||
|
# Set API key attribute as value passed or SETTINGS API key if none passed
|
||||||
|
self.api_key = api_key or SETTINGS.get('api_key', '')
|
||||||
|
|
||||||
|
# If an API key is provided
|
||||||
|
if self.api_key:
|
||||||
|
# If the provided API key matches the API key in the SETTINGS
|
||||||
|
if self.api_key == SETTINGS.get('api_key'):
|
||||||
|
# Log that the user is already logged in
|
||||||
|
LOGGER.info(f'{PREFIX}Authenticated ✅')
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# Attempt to authenticate with the provided API key
|
||||||
|
success = self.authenticate()
|
||||||
|
# If the API key is not provided and the environment is a Google Colab notebook
|
||||||
|
elif is_colab():
|
||||||
|
# Attempt to authenticate using browser cookies
|
||||||
|
success = self.auth_with_cookies()
|
||||||
|
else:
|
||||||
|
# Request an API key
|
||||||
|
success = self.request_api_key()
|
||||||
|
|
||||||
|
# Update SETTINGS with the new API key after successful authentication
|
||||||
|
if success:
|
||||||
|
set_settings({'api_key': self.api_key})
|
||||||
|
# Log that the new login was successful
|
||||||
|
LOGGER.info(f'{PREFIX}New authentication successful ✅')
|
||||||
|
else:
|
||||||
|
LOGGER.info(f'{PREFIX}Retrieve API key from {API_KEY_URL}')
|
||||||
|
|
||||||
|
def request_api_key(self, max_attempts=3):
|
||||||
|
"""
|
||||||
|
Prompt the user to input their API key. Returns the model ID.
|
||||||
|
"""
|
||||||
|
import getpass
|
||||||
|
for attempts in range(max_attempts):
|
||||||
|
LOGGER.info(f'{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}')
|
||||||
|
input_key = getpass.getpass(f'Enter API key from {API_KEY_URL} ')
|
||||||
|
self.api_key = input_key.split('_')[0] # remove model id if present
|
||||||
|
if self.authenticate():
|
||||||
|
return True
|
||||||
|
raise ConnectionError(emojis(f'{PREFIX}Failed to authenticate ❌'))
|
||||||
|
|
||||||
def authenticate(self) -> bool:
|
def authenticate(self) -> bool:
|
||||||
"""Attempt to authenticate with server"""
|
"""
|
||||||
|
Attempt to authenticate with the server using either id_token or API key.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if authentication is successful, False otherwise.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
header = self.get_auth_header()
|
header = self.get_auth_header()
|
||||||
if header:
|
if header:
|
||||||
@ -33,12 +80,16 @@ class Auth:
|
|||||||
raise ConnectionError('User has not authenticated locally.')
|
raise ConnectionError('User has not authenticated locally.')
|
||||||
except ConnectionError:
|
except ConnectionError:
|
||||||
self.id_token = self.api_key = False # reset invalid
|
self.id_token = self.api_key = False # reset invalid
|
||||||
|
LOGGER.warning(f'{PREFIX}Invalid API key ⚠️')
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def auth_with_cookies(self) -> bool:
|
def auth_with_cookies(self) -> bool:
|
||||||
"""
|
"""
|
||||||
Attempt to fetch authentication via cookies and set id_token.
|
Attempt to fetch authentication via cookies and set id_token.
|
||||||
User must be logged in to HUB and running in a supported browser.
|
User must be logged in to HUB and running in a supported browser.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if authentication is successful, False otherwise.
|
||||||
"""
|
"""
|
||||||
if not is_colab():
|
if not is_colab():
|
||||||
return False # Currently only works with Colab
|
return False # Currently only works with Colab
|
||||||
@ -54,6 +105,12 @@ class Auth:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def get_auth_header(self):
|
def get_auth_header(self):
|
||||||
|
"""
|
||||||
|
Get the authentication header for making API requests.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The authentication header if id_token or API key is set, None otherwise.
|
||||||
|
"""
|
||||||
if self.id_token:
|
if self.id_token:
|
||||||
return {'authorization': f'Bearer {self.id_token}'}
|
return {'authorization': f'Bearer {self.id_token}'}
|
||||||
elif self.api_key:
|
elif self.api_key:
|
||||||
@ -62,9 +119,19 @@ class Auth:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def get_state(self) -> bool:
|
def get_state(self) -> bool:
|
||||||
"""Get the authentication state"""
|
"""
|
||||||
|
Get the authentication state.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if either id_token or API key is set, False otherwise.
|
||||||
|
"""
|
||||||
return self.id_token or self.api_key
|
return self.id_token or self.api_key
|
||||||
|
|
||||||
def set_api_key(self, key: str):
|
def set_api_key(self, key: str):
|
||||||
"""Get the authentication state"""
|
"""
|
||||||
|
Set the API key for authentication.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key (str): The API key string.
|
||||||
|
"""
|
||||||
self.api_key = key
|
self.api_key = key
|
||||||
|
@ -6,17 +6,62 @@ from time import sleep
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from ultralytics.hub.utils import HUB_API_ROOT, check_dataset_disk_space, smart_request
|
from ultralytics.hub.utils import HUB_API_ROOT, PREFIX, check_dataset_disk_space, smart_request
|
||||||
from ultralytics.yolo.utils import LOGGER, PREFIX, __version__, checks, emojis, is_colab, threaded
|
from ultralytics.yolo.utils import LOGGER, __version__, checks, emojis, is_colab, threaded
|
||||||
|
|
||||||
AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
|
AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
|
||||||
|
|
||||||
|
|
||||||
class HUBTrainingSession:
|
class HUBTrainingSession:
|
||||||
|
"""
|
||||||
|
HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing.
|
||||||
|
|
||||||
def __init__(self, model_id, auth):
|
Args:
|
||||||
|
url (str): Model identifier used to initialize the HUB training session.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
agent_id (str): Identifier for the instance communicating with the server.
|
||||||
|
model_id (str): Identifier for the YOLOv5 model being trained.
|
||||||
|
model_url (str): URL for the model in Ultralytics HUB.
|
||||||
|
api_url (str): API URL for the model in Ultralytics HUB.
|
||||||
|
auth_header (Dict): Authentication header for the Ultralytics HUB API requests.
|
||||||
|
rate_limits (Dict): Rate limits for different API calls (in seconds).
|
||||||
|
timers (Dict): Timers for rate limiting.
|
||||||
|
metrics_queue (Dict): Queue for the model's metrics.
|
||||||
|
model (Dict): Model data fetched from Ultralytics HUB.
|
||||||
|
alive (bool): Indicates if the heartbeat loop is active.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, url):
|
||||||
|
"""
|
||||||
|
Initialize the HUBTrainingSession with the provided model identifier.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url (str): Model identifier used to initialize the HUB training session.
|
||||||
|
It can be a URL string or a model key with specific format.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the provided model identifier is invalid.
|
||||||
|
ConnectionError: If connecting with global API key is not supported.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from ultralytics.hub.auth import Auth
|
||||||
|
|
||||||
|
# Parse input
|
||||||
|
if url.startswith('https://hub.ultralytics.com/models/'):
|
||||||
|
url = url.split('https://hub.ultralytics.com/models/')[-1]
|
||||||
|
if [len(x) for x in url.split('_')] == [42, 20]:
|
||||||
|
key, model_id = url.split('_')
|
||||||
|
elif len(url) == 20:
|
||||||
|
key, model_id = '', url
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Invalid HUBTrainingSession input: {url}')
|
||||||
|
|
||||||
|
# Authorize
|
||||||
|
auth = Auth(key)
|
||||||
self.agent_id = None # identifies which instance is communicating with server
|
self.agent_id = None # identifies which instance is communicating with server
|
||||||
self.model_id = model_id
|
self.model_id = model_id
|
||||||
|
self.model_url = f'https://hub.ultralytics.com/models/{model_id}'
|
||||||
self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
|
self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
|
||||||
self.auth_header = auth.get_auth_header()
|
self.auth_header = auth.get_auth_header()
|
||||||
self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0} # rate limits (seconds)
|
self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0} # rate limits (seconds)
|
||||||
@ -26,16 +71,17 @@ class HUBTrainingSession:
|
|||||||
self.alive = True
|
self.alive = True
|
||||||
self._start_heartbeat() # start heartbeats
|
self._start_heartbeat() # start heartbeats
|
||||||
self._register_signal_handlers()
|
self._register_signal_handlers()
|
||||||
|
LOGGER.info(f'{PREFIX}View model at {self.model_url} 🚀')
|
||||||
|
|
||||||
def _register_signal_handlers(self):
|
def _register_signal_handlers(self):
|
||||||
|
"""Register signal handlers for SIGTERM and SIGINT signals to gracefully handle termination."""
|
||||||
signal.signal(signal.SIGTERM, self._handle_signal)
|
signal.signal(signal.SIGTERM, self._handle_signal)
|
||||||
signal.signal(signal.SIGINT, self._handle_signal)
|
signal.signal(signal.SIGINT, self._handle_signal)
|
||||||
|
|
||||||
def _handle_signal(self, signum, frame):
|
def _handle_signal(self, signum, frame):
|
||||||
"""
|
"""
|
||||||
Prevent heartbeats from being sent on Colab after kill.
|
Handle kill signals and prevent heartbeats from being sent on Colab after termination.
|
||||||
This method does not use frame, it is included as it is
|
This method does not use frame, it is included as it is passed by signal.
|
||||||
passed by signal.
|
|
||||||
"""
|
"""
|
||||||
if self.alive is True:
|
if self.alive is True:
|
||||||
LOGGER.info(f'{PREFIX}Kill signal received! ❌')
|
LOGGER.info(f'{PREFIX}Kill signal received! ❌')
|
||||||
@ -43,15 +89,16 @@ class HUBTrainingSession:
|
|||||||
sys.exit(signum)
|
sys.exit(signum)
|
||||||
|
|
||||||
def _stop_heartbeat(self):
|
def _stop_heartbeat(self):
|
||||||
"""End the heartbeat loop"""
|
"""Terminate the heartbeat loop."""
|
||||||
self.alive = False
|
self.alive = False
|
||||||
|
|
||||||
def upload_metrics(self):
|
def upload_metrics(self):
|
||||||
|
"""Upload model metrics to Ultralytics HUB."""
|
||||||
payload = {'metrics': self.metrics_queue.copy(), 'type': 'metrics'}
|
payload = {'metrics': self.metrics_queue.copy(), 'type': 'metrics'}
|
||||||
smart_request('post', self.api_url, json=payload, headers=self.auth_header, code=2)
|
smart_request('post', self.api_url, json=payload, headers=self.auth_header, code=2)
|
||||||
|
|
||||||
def _get_model(self):
|
def _get_model(self):
|
||||||
# Returns model from database by id
|
"""Fetch and return model data from Ultralytics HUB."""
|
||||||
api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'
|
api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -59,9 +106,7 @@ class HUBTrainingSession:
|
|||||||
data = response.json().get('data', None)
|
data = response.json().get('data', None)
|
||||||
|
|
||||||
if data.get('status', None) == 'trained':
|
if data.get('status', None) == 'trained':
|
||||||
raise ValueError(
|
raise ValueError(emojis(f'Model is already trained and uploaded to {self.model_url} 🚀'))
|
||||||
emojis(f'Model is already trained and uploaded to '
|
|
||||||
f'https://hub.ultralytics.com/models/{self.model_id} 🚀'))
|
|
||||||
|
|
||||||
if not data.get('data', None):
|
if not data.get('data', None):
|
||||||
raise ValueError('Dataset may still be processing. Please wait a minute and try again.') # RF fix
|
raise ValueError('Dataset may still be processing. Please wait a minute and try again.') # RF fix
|
||||||
@ -88,11 +133,21 @@ class HUBTrainingSession:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def check_disk_space(self):
|
def check_disk_space(self):
|
||||||
if not check_dataset_disk_space(self.model['data']):
|
"""Check if there is enough disk space for the dataset."""
|
||||||
|
if not check_dataset_disk_space(url=self.model['data']):
|
||||||
raise MemoryError('Not enough disk space')
|
raise MemoryError('Not enough disk space')
|
||||||
|
|
||||||
def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
|
def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
|
||||||
# Upload a model to HUB
|
"""
|
||||||
|
Upload a model checkpoint to Ultralytics HUB.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
epoch (int): The current training epoch.
|
||||||
|
weights (str): Path to the model weights file.
|
||||||
|
is_best (bool): Indicates if the current model is the best one so far.
|
||||||
|
map (float): Mean average precision of the model.
|
||||||
|
final (bool): Indicates if the model is the final model after training.
|
||||||
|
"""
|
||||||
if Path(weights).is_file():
|
if Path(weights).is_file():
|
||||||
with open(weights, 'rb') as f:
|
with open(weights, 'rb') as f:
|
||||||
file = f.read()
|
file = f.read()
|
||||||
@ -120,6 +175,7 @@ class HUBTrainingSession:
|
|||||||
|
|
||||||
@threaded
|
@threaded
|
||||||
def _start_heartbeat(self):
|
def _start_heartbeat(self):
|
||||||
|
"""Begin a threaded heartbeat loop to report the agent's status to Ultralytics HUB."""
|
||||||
while self.alive:
|
while self.alive:
|
||||||
r = smart_request('post',
|
r = smart_request('post',
|
||||||
f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
|
f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
|
||||||
|
@ -22,7 +22,16 @@ HUB_API_ROOT = os.environ.get('ULTRALYTICS_HUB_API', 'https://api.ultralytics.co
|
|||||||
|
|
||||||
|
|
||||||
def check_dataset_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=2.0):
|
def check_dataset_disk_space(url='https://ultralytics.com/assets/coco128.zip', sf=2.0):
|
||||||
# Check that url fits on disk with safety factor sf, i.e. require 2GB free if url size is 1GB with sf=2.0
|
"""
|
||||||
|
Check if there is sufficient disk space to download and store a dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url (str, optional): The URL to the dataset file. Defaults to 'https://ultralytics.com/assets/coco128.zip'.
|
||||||
|
sf (float, optional): Safety factor, the multiplier for the required free space. Defaults to 2.0.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if there is sufficient disk space, False otherwise.
|
||||||
|
"""
|
||||||
gib = 1 << 30 # bytes per GiB
|
gib = 1 << 30 # bytes per GiB
|
||||||
data = int(requests.head(url).headers['Content-Length']) / gib # dataset size (GB)
|
data = int(requests.head(url).headers['Content-Length']) / gib # dataset size (GB)
|
||||||
total, used, free = (x / gib for x in shutil.disk_usage('/')) # bytes
|
total, used, free = (x / gib for x in shutil.disk_usage('/')) # bytes
|
||||||
@ -35,7 +44,18 @@ def check_dataset_disk_space(url='https://ultralytics.com/assets/coco128.zip', s
|
|||||||
|
|
||||||
|
|
||||||
def request_with_credentials(url: str) -> any:
|
def request_with_credentials(url: str) -> any:
|
||||||
""" Make an ajax request with cookies attached """
|
"""
|
||||||
|
Make an AJAX request with cookies attached in a Google Colab environment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url (str): The URL to make the request to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
any: The response data from the AJAX request.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
OSError: If the function is not run in a Google Colab environment.
|
||||||
|
"""
|
||||||
if not is_colab():
|
if not is_colab():
|
||||||
raise OSError('request_with_credentials() must run in a Colab environment')
|
raise OSError('request_with_credentials() must run in a Colab environment')
|
||||||
from google.colab import output # noqa
|
from google.colab import output # noqa
|
||||||
@ -95,7 +115,6 @@ def requests_with_progress(method, url, **kwargs):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
requests.Response: The response from the HTTP request.
|
requests.Response: The response from the HTTP request.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
progress = kwargs.pop('progress', False)
|
progress = kwargs.pop('progress', False)
|
||||||
if not progress:
|
if not progress:
|
||||||
@ -126,7 +145,6 @@ def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbos
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
requests.Response: The HTTP response object. If the request is executed in a separate thread, returns None.
|
requests.Response: The HTTP response object. If the request is executed in a separate thread, returns None.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
retry_codes = (408, 500) # retry only these codes
|
retry_codes = (408, 500) # retry only these codes
|
||||||
|
|
||||||
@ -171,8 +189,8 @@ class Traces:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""
|
"""
|
||||||
Initialize Traces for error tracking and reporting if tests are not currently running.
|
Initialize Traces for error tracking and reporting if tests are not currently running.
|
||||||
|
Sets the rate limit, timer, and metadata attributes, and determines whether Traces are enabled.
|
||||||
"""
|
"""
|
||||||
from ultralytics.yolo.cfg import MODES, TASKS
|
|
||||||
self.rate_limit = 60.0 # rate limit (seconds)
|
self.rate_limit = 60.0 # rate limit (seconds)
|
||||||
self.t = 0.0 # rate limit timer (seconds)
|
self.t = 0.0 # rate limit timer (seconds)
|
||||||
self.metadata = {
|
self.metadata = {
|
||||||
@ -187,17 +205,22 @@ class Traces:
|
|||||||
not TESTS_RUNNING and \
|
not TESTS_RUNNING and \
|
||||||
ONLINE and \
|
ONLINE and \
|
||||||
(is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git')
|
(is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git')
|
||||||
self.usage = {'tasks': {k: 0 for k in TASKS}, 'modes': {k: 0 for k in MODES}}
|
self._reset_usage()
|
||||||
|
|
||||||
def __call__(self, cfg, all_keys=False, traces_sample_rate=1.0):
|
def __call__(self, cfg, all_keys=False, traces_sample_rate=1.0):
|
||||||
"""
|
"""
|
||||||
Sync traces data if enabled in the global settings
|
Sync traces data if enabled in the global settings.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg (IterableSimpleNamespace): Configuration for the task and mode.
|
cfg (IterableSimpleNamespace): Configuration for the task and mode.
|
||||||
all_keys (bool): Sync all items, not just non-default values.
|
all_keys (bool): Sync all items, not just non-default values.
|
||||||
traces_sample_rate (float): Fraction of traces captured from 0.0 to 1.0
|
traces_sample_rate (float): Fraction of traces captured from 0.0 to 1.0.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Increment usage
|
||||||
|
self.usage['modes'][cfg.mode] = self.usage['modes'].get(cfg.mode, 0) + 1
|
||||||
|
self.usage['tasks'][cfg.task] = self.usage['tasks'].get(cfg.task, 0) + 1
|
||||||
|
|
||||||
t = time.time() # current time
|
t = time.time() # current time
|
||||||
if not self.enabled or random() > traces_sample_rate:
|
if not self.enabled or random() > traces_sample_rate:
|
||||||
# Traces disabled or not randomly selected, do nothing
|
# Traces disabled or not randomly selected, do nothing
|
||||||
@ -207,18 +230,20 @@ class Traces:
|
|||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
# Time is over rate limiter, send trace now
|
# Time is over rate limiter, send trace now
|
||||||
self.t = t # reset rate limit timer
|
trace = {'uuid': SETTINGS['uuid'], 'usage': self.usage.copy(), 'metadata': self.metadata}
|
||||||
|
|
||||||
# Build trace
|
|
||||||
if cfg.task in self.usage['tasks']:
|
|
||||||
self.usage['tasks'][cfg.task] += 1
|
|
||||||
if cfg.mode in self.usage['modes']:
|
|
||||||
self.usage['modes'][cfg.mode] += 1
|
|
||||||
trace = {'uuid': SETTINGS['uuid'], 'usage': self.usage, 'metadata': self.metadata}
|
|
||||||
|
|
||||||
# Send a request to the HUB API to sync analytics
|
# Send a request to the HUB API to sync analytics
|
||||||
smart_request('post', f'{HUB_API_ROOT}/v1/usage/anonymous', json=trace, code=3, retry=0, verbose=False)
|
smart_request('post', f'{HUB_API_ROOT}/v1/usage/anonymous', json=trace, code=3, retry=0, verbose=False)
|
||||||
|
|
||||||
|
# Reset usage and rate limit timer
|
||||||
|
self._reset_usage()
|
||||||
|
self.t = t
|
||||||
|
|
||||||
|
def _reset_usage(self):
|
||||||
|
"""Reset the usage dictionary by initializing keys for each task and mode with a value of 0."""
|
||||||
|
from ultralytics.yolo.cfg import MODES, TASKS
|
||||||
|
self.usage = {'tasks': {k: 0 for k in TASKS}, 'modes': {k: 0 for k in MODES}}
|
||||||
|
|
||||||
|
|
||||||
# Run below code on hub/utils init -------------------------------------------------------------------------------------
|
# Run below code on hub/utils init -------------------------------------------------------------------------------------
|
||||||
traces = Traces()
|
traces = Traces()
|
||||||
|
@ -9,7 +9,8 @@ from types import SimpleNamespace
|
|||||||
from typing import Dict, List, Union
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, ROOT, USER_CONFIG_DIR,
|
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, ROOT, USER_CONFIG_DIR,
|
||||||
IterableSimpleNamespace, __version__, checks, colorstr, yaml_load, yaml_print)
|
IterableSimpleNamespace, __version__, checks, colorstr, get_settings, yaml_load,
|
||||||
|
yaml_print)
|
||||||
|
|
||||||
# Define valid tasks and modes
|
# Define valid tasks and modes
|
||||||
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
|
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
|
||||||
@ -187,6 +188,51 @@ def merge_equals_args(args: List[str]) -> List[str]:
|
|||||||
return new_args
|
return new_args
|
||||||
|
|
||||||
|
|
||||||
|
def handle_yolo_hub(args: List[str]) -> None:
|
||||||
|
"""
|
||||||
|
Handle Ultralytics HUB command-line interface (CLI) commands.
|
||||||
|
|
||||||
|
This function processes Ultralytics HUB CLI commands such as login and logout.
|
||||||
|
It should be called when executing a script with arguments related to HUB authentication.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args (List[str]): A list of command line arguments
|
||||||
|
|
||||||
|
Example:
|
||||||
|
python my_script.py hub login your_api_key
|
||||||
|
"""
|
||||||
|
from ultralytics import hub
|
||||||
|
|
||||||
|
if args[0] == 'login':
|
||||||
|
key = args[1] if len(args) > 1 else ''
|
||||||
|
# Log in to Ultralytics HUB using the provided API key
|
||||||
|
hub.login(key)
|
||||||
|
elif args[0] == 'logout':
|
||||||
|
# Log out from Ultralytics HUB
|
||||||
|
hub.logout()
|
||||||
|
|
||||||
|
|
||||||
|
def handle_yolo_settings(args: List[str]) -> None:
|
||||||
|
"""
|
||||||
|
Handle YOLO settings command-line interface (CLI) commands.
|
||||||
|
|
||||||
|
This function processes YOLO settings CLI commands such as reset.
|
||||||
|
It should be called when executing a script with arguments related to YOLO settings management.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args (List[str]): A list of command line arguments for YOLO settings management.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
python my_script.py yolo settings reset
|
||||||
|
"""
|
||||||
|
path = USER_CONFIG_DIR / 'settings.yaml' # get SETTINGS YAML file path
|
||||||
|
if any(args) and args[0] == 'reset':
|
||||||
|
path.unlink() # delete the settings file
|
||||||
|
get_settings() # create new settings
|
||||||
|
LOGGER.info('Settings reset successfully') # inform the user that settings have been reset
|
||||||
|
yaml_print(path) # print the current settings
|
||||||
|
|
||||||
|
|
||||||
def entrypoint(debug=''):
|
def entrypoint(debug=''):
|
||||||
"""
|
"""
|
||||||
This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
|
This function is the ultralytics package entrypoint, it's responsible for parsing the command line arguments passed
|
||||||
@ -211,8 +257,10 @@ def entrypoint(debug=''):
|
|||||||
'help': lambda: LOGGER.info(CLI_HELP_MSG),
|
'help': lambda: LOGGER.info(CLI_HELP_MSG),
|
||||||
'checks': checks.check_yolo,
|
'checks': checks.check_yolo,
|
||||||
'version': lambda: LOGGER.info(__version__),
|
'version': lambda: LOGGER.info(__version__),
|
||||||
'settings': lambda: yaml_print(USER_CONFIG_DIR / 'settings.yaml'),
|
'settings': lambda: handle_yolo_settings(args[1:]),
|
||||||
'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
|
'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
|
||||||
|
'hub': lambda: handle_yolo_hub(args[1:]),
|
||||||
|
'login': lambda: handle_yolo_hub(args),
|
||||||
'copy-cfg': copy_default_cfg}
|
'copy-cfg': copy_default_cfg}
|
||||||
full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}
|
full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}
|
||||||
|
|
||||||
@ -255,8 +303,8 @@ def entrypoint(debug=''):
|
|||||||
overrides['task'] = a
|
overrides['task'] = a
|
||||||
elif a in MODES:
|
elif a in MODES:
|
||||||
overrides['mode'] = a
|
overrides['mode'] = a
|
||||||
elif a in special:
|
elif a.lower() in special:
|
||||||
special[a]()
|
special[a.lower()]()
|
||||||
return
|
return
|
||||||
elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool):
|
elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool):
|
||||||
overrides[a] = True # auto-True for default bool args, i.e. 'yolo show' sets show=True
|
overrides[a] = True # auto-True for default bool args, i.e. 'yolo show' sets show=True
|
||||||
|
@ -68,12 +68,14 @@ class YOLO:
|
|||||||
list(ultralytics.yolo.engine.results.Results): The prediction results.
|
list(ultralytics.yolo.engine.results.Results): The prediction results.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None, session=None) -> None:
|
def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None) -> None:
|
||||||
"""
|
"""
|
||||||
Initializes the YOLO model.
|
Initializes the YOLO model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (str, Path): model to load or create
|
model (Union[str, Path], optional): Path or name of the model to load or create. Defaults to 'yolov8n.pt'.
|
||||||
|
task (Any, optional): Task type for the YOLO model. Defaults to None.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self._reset_callbacks()
|
self._reset_callbacks()
|
||||||
self.predictor = None # reuse predictor
|
self.predictor = None # reuse predictor
|
||||||
@ -85,10 +87,16 @@ class YOLO:
|
|||||||
self.ckpt_path = None
|
self.ckpt_path = None
|
||||||
self.overrides = {} # overrides for trainer object
|
self.overrides = {} # overrides for trainer object
|
||||||
self.metrics = None # validation/training metrics
|
self.metrics = None # validation/training metrics
|
||||||
self.session = session # HUB session
|
self.session = None # HUB session
|
||||||
|
model = str(model).strip() # strip spaces
|
||||||
|
|
||||||
|
# Check if Ultralytics HUB model from https://hub.ultralytics.com
|
||||||
|
if model.startswith('https://hub.ultralytics.com/models/'):
|
||||||
|
from ultralytics.hub import HUBTrainingSession
|
||||||
|
self.session = HUBTrainingSession(model)
|
||||||
|
model = self.session.model_file
|
||||||
|
|
||||||
# Load or create new YOLO model
|
# Load or create new YOLO model
|
||||||
model = str(model).strip() # strip spaces
|
|
||||||
suffix = Path(model).suffix
|
suffix = Path(model).suffix
|
||||||
if not suffix and Path(model).stem in GITHUB_ASSET_STEMS:
|
if not suffix and Path(model).stem in GITHUB_ASSET_STEMS:
|
||||||
model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt
|
model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt
|
||||||
@ -280,6 +288,7 @@ class YOLO:
|
|||||||
from ultralytics.yolo.utils.benchmarks import benchmark
|
from ultralytics.yolo.utils.benchmarks import benchmark
|
||||||
overrides = self.model.args.copy()
|
overrides = self.model.args.copy()
|
||||||
overrides.update(kwargs)
|
overrides.update(kwargs)
|
||||||
|
overrides['mode'] = 'benchmark'
|
||||||
overrides = {**DEFAULT_CFG_DICT, **overrides} # fill in missing overrides keys with defaults
|
overrides = {**DEFAULT_CFG_DICT, **overrides} # fill in missing overrides keys with defaults
|
||||||
return benchmark(model=self, imgsz=overrides['imgsz'], half=overrides['half'], device=overrides['device'])
|
return benchmark(model=self, imgsz=overrides['imgsz'], half=overrides['half'], device=overrides['device'])
|
||||||
|
|
||||||
@ -293,6 +302,7 @@ class YOLO:
|
|||||||
self._check_is_pytorch_model()
|
self._check_is_pytorch_model()
|
||||||
overrides = self.overrides.copy()
|
overrides = self.overrides.copy()
|
||||||
overrides.update(kwargs)
|
overrides.update(kwargs)
|
||||||
|
overrides['mode'] = 'export'
|
||||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||||
args.task = self.task
|
args.task = self.task
|
||||||
if args.imgsz == DEFAULT_CFG.imgsz:
|
if args.imgsz == DEFAULT_CFG.imgsz:
|
||||||
@ -309,6 +319,11 @@ class YOLO:
|
|||||||
**kwargs (Any): Any number of arguments representing the training configuration.
|
**kwargs (Any): Any number of arguments representing the training configuration.
|
||||||
"""
|
"""
|
||||||
self._check_is_pytorch_model()
|
self._check_is_pytorch_model()
|
||||||
|
if self.session: # Ultralytics HUB session
|
||||||
|
if any(kwargs):
|
||||||
|
LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
|
||||||
|
kwargs = self.session.train_args
|
||||||
|
self.session.check_disk_space()
|
||||||
check_pip_update_available()
|
check_pip_update_available()
|
||||||
overrides = self.overrides.copy()
|
overrides = self.overrides.copy()
|
||||||
overrides.update(kwargs)
|
overrides.update(kwargs)
|
||||||
|
@ -277,6 +277,8 @@ class Masks(SimpleClass):
|
|||||||
self.masks = masks # N, h, w
|
self.masks = masks # N, h, w
|
||||||
self.orig_shape = orig_shape
|
self.orig_shape = orig_shape
|
||||||
|
|
||||||
|
@property
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
def segments(self):
|
def segments(self):
|
||||||
# Segments-deprecated (normalized)
|
# Segments-deprecated (normalized)
|
||||||
LOGGER.warning("WARNING ⚠️ 'Masks.segments' is deprecated. Use 'Masks.xyn' for segments (normalized) and "
|
LOGGER.warning("WARNING ⚠️ 'Masks.segments' is deprecated. Use 'Masks.xyn' for segments (normalized) and "
|
||||||
|
@ -321,10 +321,13 @@ def is_online() -> bool:
|
|||||||
bool: True if connection is successful, False otherwise.
|
bool: True if connection is successful, False otherwise.
|
||||||
"""
|
"""
|
||||||
import socket
|
import socket
|
||||||
with contextlib.suppress(Exception):
|
|
||||||
host = socket.gethostbyname('www.github.com')
|
for server in '1.1.1.1', '8.8.8.8', '223.5.5.5': # Cloudflare, Google, AliDNS:
|
||||||
socket.create_connection((host, 80), timeout=2)
|
try:
|
||||||
return True
|
socket.create_connection((server, 53), timeout=2) # connect to (server, port=53)
|
||||||
|
return True
|
||||||
|
except (socket.timeout, socket.gaierror, OSError):
|
||||||
|
continue
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@ -586,7 +589,7 @@ def set_sentry():
|
|||||||
logging.getLogger(logger).setLevel(logging.CRITICAL)
|
logging.getLogger(logger).setLevel(logging.CRITICAL)
|
||||||
|
|
||||||
|
|
||||||
def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.2'):
|
def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.3'):
|
||||||
"""
|
"""
|
||||||
Loads a global Ultralytics settings YAML file or creates one with default values if it does not exist.
|
Loads a global Ultralytics settings YAML file or creates one with default values if it does not exist.
|
||||||
|
|
||||||
@ -609,8 +612,9 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.2'):
|
|||||||
'datasets_dir': str(datasets_root / 'datasets'), # default datasets directory.
|
'datasets_dir': str(datasets_root / 'datasets'), # default datasets directory.
|
||||||
'weights_dir': str(root / 'weights'), # default weights directory.
|
'weights_dir': str(root / 'weights'), # default weights directory.
|
||||||
'runs_dir': str(root / 'runs'), # default runs directory.
|
'runs_dir': str(root / 'runs'), # default runs directory.
|
||||||
'sync': True, # sync analytics to help with YOLO development
|
|
||||||
'uuid': hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(), # anonymized uuid hash
|
'uuid': hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(), # anonymized uuid hash
|
||||||
|
'sync': True, # sync analytics to help with YOLO development
|
||||||
|
'api_key': '', # Ultralytics HUB API key (https://hub.ultralytics.com/)
|
||||||
'settings_version': version} # Ultralytics settings version
|
'settings_version': version} # Ultralytics settings version
|
||||||
|
|
||||||
with torch_distributed_zero_first(RANK):
|
with torch_distributed_zero_first(RANK):
|
||||||
|
@ -25,7 +25,7 @@ def on_pretrain_routine_end(trainer):
|
|||||||
mlflow_location = os.environ['MLFLOW_TRACKING_URI'] # "http://192.168.xxx.xxx:5000"
|
mlflow_location = os.environ['MLFLOW_TRACKING_URI'] # "http://192.168.xxx.xxx:5000"
|
||||||
mlflow.set_tracking_uri(mlflow_location)
|
mlflow.set_tracking_uri(mlflow_location)
|
||||||
|
|
||||||
experiment_name = trainer.args.project or 'YOLOv8'
|
experiment_name = trainer.args.project or '/Shared/YOLOv8'
|
||||||
experiment = mlflow.get_experiment_by_name(experiment_name)
|
experiment = mlflow.get_experiment_by_name(experiment_name)
|
||||||
if experiment is None:
|
if experiment is None:
|
||||||
mlflow.create_experiment(experiment_name)
|
mlflow.create_experiment(experiment_name)
|
||||||
@ -33,16 +33,15 @@ def on_pretrain_routine_end(trainer):
|
|||||||
|
|
||||||
prefix = colorstr('MLFlow: ')
|
prefix = colorstr('MLFlow: ')
|
||||||
try:
|
try:
|
||||||
run, active_run = mlflow, mlflow.start_run() if mlflow else None
|
run, active_run = mlflow, mlflow.active_run()
|
||||||
if active_run is not None:
|
if not active_run:
|
||||||
run_id = active_run.info.run_id
|
active_run = mlflow.start_run(experiment_id=experiment.experiment_id)
|
||||||
LOGGER.info(f'{prefix}Using run_id({run_id}) at {mlflow_location}')
|
run_id = active_run.info.run_id
|
||||||
|
LOGGER.info(f'{prefix}Using run_id({run_id}) at {mlflow_location}')
|
||||||
|
run.log_params(vars(trainer.model.args))
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
LOGGER.error(f'{prefix}Failing init - {repr(err)}')
|
LOGGER.error(f'{prefix}Failing init - {repr(err)}')
|
||||||
LOGGER.warning(f'{prefix}Continuing without Mlflow')
|
LOGGER.warning(f'{prefix}Continuing without Mlflow')
|
||||||
run = None
|
|
||||||
|
|
||||||
run.log_params(vars(trainer.model.args))
|
|
||||||
|
|
||||||
|
|
||||||
def on_fit_epoch_end(trainer):
|
def on_fit_epoch_end(trainer):
|
||||||
|
@ -142,7 +142,7 @@ def check_pip_update_available():
|
|||||||
bool: True if an update is available, False otherwise.
|
bool: True if an update is available, False otherwise.
|
||||||
"""
|
"""
|
||||||
if ONLINE and is_pip_package():
|
if ONLINE and is_pip_package():
|
||||||
with contextlib.suppress(ConnectionError):
|
with contextlib.suppress(Exception):
|
||||||
from ultralytics import __version__
|
from ultralytics import __version__
|
||||||
latest = check_latest_pypi_version()
|
latest = check_latest_pypi_version()
|
||||||
if pkg.parse_version(__version__) < pkg.parse_version(latest): # update is available
|
if pkg.parse_version(__version__) < pkg.parse_version(latest): # update is available
|
||||||
|
@ -12,7 +12,7 @@ import requests
|
|||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from ultralytics.yolo.utils import LOGGER, checks, is_online
|
from ultralytics.yolo.utils import LOGGER, checks, emojis, is_online
|
||||||
|
|
||||||
GITHUB_ASSET_NAMES = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] + \
|
GITHUB_ASSET_NAMES = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] + \
|
||||||
[f'yolov5{size}u.pt' for size in 'nsmlx'] + \
|
[f'yolov5{size}u.pt' for size in 'nsmlx'] + \
|
||||||
@ -113,9 +113,9 @@ def safe_download(url,
|
|||||||
f.unlink() # remove partial downloads
|
f.unlink() # remove partial downloads
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if i == 0 and not is_online():
|
if i == 0 and not is_online():
|
||||||
raise ConnectionError(f'❌ Download failure for {url}. Environment is not online.') from e
|
raise ConnectionError(emojis(f'❌ Download failure for {url}. Environment is not online.')) from e
|
||||||
elif i >= retry:
|
elif i >= retry:
|
||||||
raise ConnectionError(f'❌ Download failure for {url}. Retry limit reached.') from e
|
raise ConnectionError(emojis(f'❌ Download failure for {url}. Retry limit reached.')) from e
|
||||||
LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
|
LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
|
||||||
|
|
||||||
if unzip and f.exists() and f.suffix in ('.zip', '.tar', '.gz'):
|
if unzip and f.exists() and f.suffix in ('.zip', '.tar', '.gz'):
|
||||||
|
@ -114,7 +114,7 @@ class Annotator:
|
|||||||
self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
|
self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
|
||||||
if im_gpu.device != masks.device:
|
if im_gpu.device != masks.device:
|
||||||
im_gpu = im_gpu.to(masks.device)
|
im_gpu = im_gpu.to(masks.device)
|
||||||
colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0
|
colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0 # shape(n,3)
|
||||||
colors = colors[:, None, None] # shape(n,1,1,3)
|
colors = colors[:, None, None] # shape(n,1,1,3)
|
||||||
masks = masks.unsqueeze(3) # shape(n,h,w,1)
|
masks = masks.unsqueeze(3) # shape(n,h,w,1)
|
||||||
masks_color = masks * (colors * alpha) # shape(n,h,w,3)
|
masks_color = masks * (colors * alpha) # shape(n,h,w,3)
|
||||||
|
@ -78,7 +78,7 @@ class SegmentationPredictor(DetectionPredictor):
|
|||||||
for j, d in enumerate(reversed(det)):
|
for j, d in enumerate(reversed(det)):
|
||||||
c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item())
|
c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item())
|
||||||
if self.args.save_txt: # Write to file
|
if self.args.save_txt: # Write to file
|
||||||
seg = mask.segments[len(det) - j - 1].copy().reshape(-1) # reversed mask.segments, (n,2) to (n*2)
|
seg = mask.xyn[len(det) - j - 1].copy().reshape(-1) # reversed mask.xyn, (n,2) to (n*2)
|
||||||
line = (c, *seg) + (conf, ) * self.args.save_conf + (() if id is None else (id, ))
|
line = (c, *seg) + (conf, ) * self.args.save_conf + (() if id is None else (id, ))
|
||||||
with open(f'{self.txt_path}.txt', 'a') as f:
|
with open(f'{self.txt_path}.txt', 'a') as f:
|
||||||
f.write(('%g ' * len(line)).rstrip() % line + '\n')
|
f.write(('%g ' * len(line)).rstrip() % line + '\n')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user