mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Check PyTorch model status for all YOLO
methods (#945)
Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
parent
fd5be10c66
commit
20fe708f31
9
.github/workflows/ci.yaml
vendored
9
.github/workflows/ci.yaml
vendored
@ -29,7 +29,7 @@ jobs:
|
|||||||
- os: ubuntu-latest
|
- os: ubuntu-latest
|
||||||
python-version: '3.8' # torch 1.7.0 requires python >=3.6, <=3.8
|
python-version: '3.8' # torch 1.7.0 requires python >=3.6, <=3.8
|
||||||
model: yolov8n
|
model: yolov8n
|
||||||
torch: '1.7.0' # min torch version CI https://pypi.org/project/torchvision/
|
torch: '1.8.0' # min torch version CI https://pypi.org/project/torchvision/
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v3
|
||||||
- uses: actions/setup-python@v4
|
- uses: actions/setup-python@v4
|
||||||
@ -48,13 +48,12 @@ jobs:
|
|||||||
- name: Install requirements
|
- name: Install requirements
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip wheel
|
python -m pip install --upgrade pip wheel
|
||||||
if [ "${{ matrix.torch }}" == "1.7.0" ]; then
|
if [ "${{ matrix.torch }}" == "1.8.0" ]; then
|
||||||
pip install -r requirements.txt torch==1.7.0 torchvision==0.8.1 --extra-index-url https://download.pytorch.org/whl/cpu
|
pip install -e . torch==1.8.0 torchvision==0.9.0 onnx openvino-dev>=2022.3 pytest --extra-index-url https://download.pytorch.org/whl/cpu
|
||||||
else
|
else
|
||||||
pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cpu
|
pip install -e . onnx openvino-dev>=2022.3 pytest --extra-index-url https://download.pytorch.org/whl/cpu
|
||||||
fi
|
fi
|
||||||
# pip install ultralytics (production)
|
# pip install ultralytics (production)
|
||||||
pip install -e . pytest
|
|
||||||
shell: bash # for Windows compatibility
|
shell: bash # for Windows compatibility
|
||||||
- name: Check environment
|
- name: Check environment
|
||||||
run: |
|
run: |
|
||||||
|
2
.github/workflows/cla.yml
vendored
2
.github/workflows/cla.yml
vendored
@ -18,7 +18,7 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- name: "CLA Assistant"
|
- name: "CLA Assistant"
|
||||||
if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I sign the CLA') || github.event_name == 'pull_request_target'
|
if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I sign the CLA') || github.event_name == 'pull_request_target'
|
||||||
uses: contributor-assistant/github-action@v2.2.1
|
uses: contributor-assistant/github-action@v2.3.0
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
# must be repository secret token
|
# must be repository secret token
|
||||||
|
@ -114,8 +114,8 @@ We are still working on several parts of YOLOv8! We aim to have these completed
|
|||||||
to par with YOLOv5, including export and inference to all the same formats. We are also writing a YOLOv8 paper which we
|
to par with YOLOv5, including export and inference to all the same formats. We are also writing a YOLOv8 paper which we
|
||||||
will submit to [arxiv.org](https://arxiv.org) once complete.
|
will submit to [arxiv.org](https://arxiv.org) once complete.
|
||||||
|
|
||||||
- [ ] TensorFlow exports
|
- [x] TensorFlow exports
|
||||||
- [ ] DDP resume
|
- [x] DDP resume
|
||||||
- [ ] [arxiv.org](https://arxiv.org) paper
|
- [ ] [arxiv.org](https://arxiv.org) paper
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
@ -246,8 +246,7 @@ YOLOv8 is available under two different licenses:
|
|||||||
|
|
||||||
## <div align="center">Contact</div>
|
## <div align="center">Contact</div>
|
||||||
|
|
||||||
For YOLOv8 bugs and feature requests please visit [GitHub Issues](https://github.com/ultralytics/ultralytics/issues).
|
For YOLOv8 bug reports and feature requests please visit [GitHub Issues](https://github.com/ultralytics/ultralytics/issues) or the [Ultralytics Community Forum](https://community.ultralytics.com/).
|
||||||
For professional support please [Contact Us](https://ultralytics.com/contact).
|
|
||||||
|
|
||||||
<br>
|
<br>
|
||||||
<div align="center">
|
<div align="center">
|
||||||
|
@ -101,8 +101,8 @@ success = model.export(format="onnx") # 将模型导出为 ONNX 格式
|
|||||||
|
|
||||||
我们仍在努力完善 YOLOv8 的几个部分!我们的目标是尽快完成这些工作,使 YOLOv8 的功能设置达到YOLOv5 的水平,包括对所有相同格式的导出和推理。我们还在写一篇 YOLOv8 的论文,一旦完成,我们将提交给 [arxiv.org](https://arxiv.org)。
|
我们仍在努力完善 YOLOv8 的几个部分!我们的目标是尽快完成这些工作,使 YOLOv8 的功能设置达到YOLOv5 的水平,包括对所有相同格式的导出和推理。我们还在写一篇 YOLOv8 的论文,一旦完成,我们将提交给 [arxiv.org](https://arxiv.org)。
|
||||||
|
|
||||||
- [ ] TensorFlow 导出
|
- [x] TensorFlow 导出
|
||||||
- [ ] DDP 恢复训练
|
- [x] DDP 恢复训练
|
||||||
- [ ] [arxiv.org](https://arxiv.org) 论文
|
- [ ] [arxiv.org](https://arxiv.org) 论文
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
@ -214,7 +214,7 @@ success = model.export(format="onnx") # 将模型导出为 ONNX 格式
|
|||||||
|
|
||||||
## <div align="center">联系我们</div>
|
## <div align="center">联系我们</div>
|
||||||
|
|
||||||
若发现 YOLOv8 的 Bug 或有功能需求,请访问 [GitHub 问题](https://github.com/ultralytics/ultralytics/issues)。如需专业支持,请 [联系我们](https://ultralytics.com/contact)。
|
请访问 [GitHub Issues](https://github.com/ultralytics/ultralytics/issues) 或 [Ultralytics Community Forum](https://community.ultralytis.com) 以报告 YOLOv8 错误和请求功能。
|
||||||
|
|
||||||
<br>
|
<br>
|
||||||
<div align="center">
|
<div align="center">
|
||||||
|
17
docs/SECURITY.md
Normal file
17
docs/SECURITY.md
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
At [Ultralytics](https://ultralytics.com), the security of our users' data and systems is of utmost importance. To ensure the safety and security of our [open-source projects](https://github.com/ultralytics), we have implemented several measures to detect and prevent security vulnerabilities.
|
||||||
|
|
||||||
|
[](https://snyk.io/advisor/python/ultralytics)
|
||||||
|
|
||||||
|
## Snyk Scanning
|
||||||
|
|
||||||
|
We use [Snyk](https://snyk.io/advisor/python/ultralytics) to regularly scan the YOLOv8 repository for vulnerabilities and security issues. Our goal is to identify and remediate any potential threats as soon as possible, to minimize any risks to our users.
|
||||||
|
|
||||||
|
## GitHub CodeQL Scanning
|
||||||
|
|
||||||
|
In addition to our Snyk scans, we also use GitHub's [CodeQL](https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/about-code-scanning-with-codeql) scans to proactively identify and address security vulnerabilities.
|
||||||
|
|
||||||
|
## Reporting Security Issues
|
||||||
|
|
||||||
|
If you suspect or discover a security vulnerability in the YOLOv8 repository, please let us know immediately. You can reach out to us directly via our [contact form](https://ultralytics.com/contact) or via [security@ultralytics.com](mailto:security@ultralytics.com). Our security team will investigate and respond as soon as possible.
|
||||||
|
|
||||||
|
We appreciate your help in keeping the YOLOv8 repository secure and safe for everyone.
|
@ -122,3 +122,4 @@ nav:
|
|||||||
- Results: reference/results.md
|
- Results: reference/results.md
|
||||||
- ultralytics.nn: reference/nn.md
|
- ultralytics.nn: reference/nn.md
|
||||||
- Operations: reference/ops.md
|
- Operations: reference/ops.md
|
||||||
|
- Security: SECURITY.md
|
||||||
|
@ -48,18 +48,18 @@ def test_val_classify():
|
|||||||
|
|
||||||
# Predict checks -------------------------------------------------------------------------------------------------------
|
# Predict checks -------------------------------------------------------------------------------------------------------
|
||||||
def test_predict_detect():
|
def test_predict_detect():
|
||||||
run(f"yolo predict detect model={MODEL}.pt source={ROOT / 'assets'} imgsz=32")
|
run(f"yolo predict model={MODEL}.pt source={ROOT / 'assets'} imgsz=32")
|
||||||
run(f"yolo predict detect model={MODEL}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32")
|
run(f"yolo predict model={MODEL}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32")
|
||||||
run(f"yolo predict detect model={MODEL}.pt source=https://ultralytics.com/assets/decelera_landscape.mov imgsz=32")
|
run(f"yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_landscape_min.mov imgsz=32")
|
||||||
run(f"yolo predict detect model={MODEL}.pt source=https://ultralytics.com/assets/decelera_portrait.mov imgsz=32")
|
run(f"yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_portrait_min.mov imgsz=32")
|
||||||
|
|
||||||
|
|
||||||
def test_predict_segment():
|
def test_predict_segment():
|
||||||
run(f"yolo predict segment model={MODEL}-seg.pt source={ROOT / 'assets'} imgsz=32")
|
run(f"yolo predict model={MODEL}-seg.pt source={ROOT / 'assets'} imgsz=32")
|
||||||
|
|
||||||
|
|
||||||
def test_predict_classify():
|
def test_predict_classify():
|
||||||
run(f"yolo predict classify model={MODEL}-cls.pt source={ROOT / 'assets'} imgsz=32")
|
run(f"yolo predict model={MODEL}-cls.pt source={ROOT / 'assets'} imgsz=32")
|
||||||
|
|
||||||
|
|
||||||
# Export checks --------------------------------------------------------------------------------------------------------
|
# Export checks --------------------------------------------------------------------------------------------------------
|
||||||
|
@ -18,7 +18,6 @@ SOURCE = ROOT / 'assets/bus.jpg'
|
|||||||
|
|
||||||
def test_model_forward():
|
def test_model_forward():
|
||||||
model = YOLO(CFG)
|
model = YOLO(CFG)
|
||||||
model.predict(SOURCE)
|
|
||||||
model(SOURCE)
|
model(SOURCE)
|
||||||
|
|
||||||
|
|
||||||
@ -38,11 +37,10 @@ def test_model_fuse():
|
|||||||
|
|
||||||
def test_predict_dir():
|
def test_predict_dir():
|
||||||
model = YOLO(MODEL)
|
model = YOLO(MODEL)
|
||||||
model.predict(source=ROOT / "assets")
|
model(source=ROOT / "assets")
|
||||||
|
|
||||||
|
|
||||||
def test_predict_img():
|
def test_predict_img():
|
||||||
|
|
||||||
model = YOLO(MODEL)
|
model = YOLO(MODEL)
|
||||||
img = Image.open(str(SOURCE))
|
img = Image.open(str(SOURCE))
|
||||||
output = model(source=img, save=True, verbose=True) # PIL
|
output = model(source=img, save=True, verbose=True) # PIL
|
||||||
@ -106,22 +104,26 @@ def test_export_torchscript():
|
|||||||
print(export_formats())
|
print(export_formats())
|
||||||
|
|
||||||
model = YOLO(MODEL)
|
model = YOLO(MODEL)
|
||||||
model.export(format='torchscript')
|
f = model.export(format='torchscript')
|
||||||
|
YOLO(f)(SOURCE) # exported model inference
|
||||||
|
|
||||||
|
|
||||||
def test_export_onnx():
|
def test_export_onnx():
|
||||||
model = YOLO(MODEL)
|
model = YOLO(MODEL)
|
||||||
model.export(format='onnx')
|
f = model.export(format='onnx')
|
||||||
|
YOLO(f)(SOURCE) # exported model inference
|
||||||
|
|
||||||
|
|
||||||
def test_export_openvino():
|
def test_export_openvino():
|
||||||
model = YOLO(MODEL)
|
model = YOLO(MODEL)
|
||||||
model.export(format='openvino')
|
f = model.export(format='openvino')
|
||||||
|
YOLO(f)(SOURCE) # exported model inference
|
||||||
|
|
||||||
|
|
||||||
def test_export_coreml():
|
def test_export_coreml():
|
||||||
model = YOLO(MODEL)
|
model = YOLO(MODEL)
|
||||||
model.export(format='coreml')
|
model.export(format='coreml')
|
||||||
|
# YOLO(f)(SOURCE) # model prediction only supported on macOS
|
||||||
|
|
||||||
|
|
||||||
def test_export_paddle(enabled=False):
|
def test_export_paddle(enabled=False):
|
||||||
@ -140,6 +142,7 @@ def test_workflow():
|
|||||||
model = YOLO(MODEL)
|
model = YOLO(MODEL)
|
||||||
model.train(data="coco8.yaml", epochs=1, imgsz=32)
|
model.train(data="coco8.yaml", epochs=1, imgsz=32)
|
||||||
model.val()
|
model.val()
|
||||||
|
print(model.metrics)
|
||||||
model.predict(SOURCE)
|
model.predict(SOURCE)
|
||||||
model.export(format="onnx", opset=12) # export a model to ONNX format
|
model.export(format="onnx", opset=12) # export a model to ONNX format
|
||||||
|
|
||||||
@ -164,6 +167,3 @@ def test_predict_callback_and_setup():
|
|||||||
print('test_callback', bs)
|
print('test_callback', bs)
|
||||||
boxes = result.boxes # Boxes object for bbox outputs
|
boxes = result.boxes # Boxes object for bbox outputs
|
||||||
print(boxes)
|
print(boxes)
|
||||||
|
|
||||||
|
|
||||||
test_predict_img()
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
|
|
||||||
__version__ = "8.0.35"
|
__version__ = "8.0.36"
|
||||||
|
|
||||||
from ultralytics.yolo.engine.model import YOLO
|
from ultralytics.yolo.engine.model import YOLO
|
||||||
from ultralytics.yolo.utils.checks import check_yolo as checks
|
from ultralytics.yolo.utils.checks import check_yolo as checks
|
||||||
|
@ -5,12 +5,12 @@ import requests
|
|||||||
from ultralytics.hub.auth import Auth
|
from ultralytics.hub.auth import Auth
|
||||||
from ultralytics.hub.session import HubTrainingSession
|
from ultralytics.hub.session import HubTrainingSession
|
||||||
from ultralytics.hub.utils import split_key
|
from ultralytics.hub.utils import split_key
|
||||||
from ultralytics.yolo.engine.exporter import export_formats
|
from ultralytics.yolo.engine.exporter import EXPORT_FORMATS_LIST
|
||||||
from ultralytics.yolo.engine.model import YOLO
|
from ultralytics.yolo.engine.model import YOLO
|
||||||
from ultralytics.yolo.utils import LOGGER, PREFIX, emojis
|
from ultralytics.yolo.utils import LOGGER, PREFIX, emojis
|
||||||
|
|
||||||
# Define all export formats
|
# Define all export formats
|
||||||
EXPORT_FORMATS = list(export_formats()['Argument'][1:]) + ["ultralytics_tflite", "ultralytics_coreml"]
|
EXPORT_FORMATS_HUB = EXPORT_FORMATS_LIST + ["ultralytics_tflite", "ultralytics_coreml"]
|
||||||
|
|
||||||
|
|
||||||
def start(key=""):
|
def start(key=""):
|
||||||
@ -69,7 +69,7 @@ def reset_model(key=""):
|
|||||||
|
|
||||||
def export_model(key="", format="torchscript"):
|
def export_model(key="", format="torchscript"):
|
||||||
# Export a model to all formats
|
# Export a model to all formats
|
||||||
assert format in EXPORT_FORMATS, f"Unsupported export format '{format}' passed, valid formats are {EXPORT_FORMATS}"
|
assert format in EXPORT_FORMATS_HUB, f"Unsupported export format '{format}', valid formats are {EXPORT_FORMATS_HUB}"
|
||||||
api_key, model_id = split_key(key)
|
api_key, model_id = split_key(key)
|
||||||
r = requests.post("https://api.ultralytics.com/export",
|
r = requests.post("https://api.ultralytics.com/export",
|
||||||
json={
|
json={
|
||||||
@ -82,7 +82,7 @@ def export_model(key="", format="torchscript"):
|
|||||||
|
|
||||||
def get_export(key="", format="torchscript"):
|
def get_export(key="", format="torchscript"):
|
||||||
# Get an exported model dictionary with download URL
|
# Get an exported model dictionary with download URL
|
||||||
assert format in EXPORT_FORMATS, f"Unsupported export format '{format}' passed, valid formats are {EXPORT_FORMATS}"
|
assert format in EXPORT_FORMATS_HUB, f"Unsupported export format '{format}', valid formats are {EXPORT_FORMATS_HUB}"
|
||||||
api_key, model_id = split_key(key)
|
api_key, model_id = split_key(key)
|
||||||
r = requests.post("https://api.ultralytics.com/get-export",
|
r = requests.post("https://api.ultralytics.com/get-export",
|
||||||
json={
|
json={
|
||||||
|
@ -193,7 +193,7 @@ class AutoBackend(nn.Module):
|
|||||||
from tflite_runtime.interpreter import Interpreter, load_delegate
|
from tflite_runtime.interpreter import Interpreter, load_delegate
|
||||||
except ImportError:
|
except ImportError:
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
|
Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
|
||||||
if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
|
if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
|
||||||
LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
|
LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
|
||||||
delegate = {
|
delegate = {
|
||||||
@ -232,8 +232,10 @@ class AutoBackend(nn.Module):
|
|||||||
nhwc = model.runtime.startswith("tensorflow")
|
nhwc = model.runtime.startswith("tensorflow")
|
||||||
'''
|
'''
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"ERROR: '{w}' is not a supported format. For supported formats see "
|
from ultralytics.yolo.engine.exporter import EXPORT_FORMATS_TABLE
|
||||||
f"https://docs.ultralytics.com/reference/nn/")
|
raise TypeError(f"model='{w}' is not a supported model format. "
|
||||||
|
"See https://docs.ultralytics.com/tasks/detection/#export for help."
|
||||||
|
f"\n\n{EXPORT_FORMATS_TABLE}")
|
||||||
|
|
||||||
# class names
|
# class names
|
||||||
if 'names' not in locals(): # names missing
|
if 'names' not in locals(): # names missing
|
||||||
|
@ -356,7 +356,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
|||||||
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||||
|
|
||||||
# Model compatibility updates
|
# Model compatibility updates
|
||||||
model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
|
model.args = args # attach args to model
|
||||||
model.pt_path = weights # attach *.pt file path to model
|
model.pt_path = weights # attach *.pt file path to model
|
||||||
model.task = guess_model_task(model)
|
model.task = guess_model_task(model)
|
||||||
if not hasattr(model, 'stride'):
|
if not hasattr(model, 'stride'):
|
||||||
|
@ -12,8 +12,8 @@ from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_P
|
|||||||
IterableSimpleNamespace, __version__, checks, colorstr, yaml_load, yaml_print)
|
IterableSimpleNamespace, __version__, checks, colorstr, yaml_load, yaml_print)
|
||||||
|
|
||||||
CLI_HELP_MSG = \
|
CLI_HELP_MSG = \
|
||||||
"""
|
f"""
|
||||||
YOLOv8 'yolo' CLI commands use the following syntax:
|
Arguments received: {str(['yolo'] + sys.argv[1:])}. Note that Ultralytics 'yolo' commands use the following syntax:
|
||||||
|
|
||||||
yolo TASK MODE ARGS
|
yolo TASK MODE ARGS
|
||||||
|
|
||||||
@ -64,9 +64,7 @@ CFG_BOOL_KEYS = {
|
|||||||
|
|
||||||
def cfg2dict(cfg):
|
def cfg2dict(cfg):
|
||||||
"""
|
"""
|
||||||
Convert a configuration object to a dictionary.
|
Convert a configuration object to a dictionary, whether it is a file path, a string, or a SimpleNamespace object.
|
||||||
|
|
||||||
This function converts a configuration object to a dictionary, whether it is a file path, a string, or a SimpleNamespace object.
|
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
cfg (str) or (Path) or (SimpleNamespace): Configuration object to be converted to a dictionary.
|
cfg (str) or (Path) or (SimpleNamespace): Configuration object to be converted to a dictionary.
|
||||||
@ -143,8 +141,9 @@ def check_cfg_mismatch(base: Dict, custom: Dict, e=None):
|
|||||||
if mismatched:
|
if mismatched:
|
||||||
string = ''
|
string = ''
|
||||||
for x in mismatched:
|
for x in mismatched:
|
||||||
matches = get_close_matches(x, base)
|
matches = get_close_matches(x, base) # key list
|
||||||
match_str = f"Similar arguments are {matches}." if matches else ''
|
matches = [f"{k}={DEFAULT_CFG_DICT[k]}" if DEFAULT_CFG_DICT[k] is not None else k for k in matches] # k=v
|
||||||
|
match_str = f"Similar arguments are i.e. {matches}." if matches else ''
|
||||||
string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n"
|
string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n"
|
||||||
raise SyntaxError(string + CLI_HELP_MSG) from e
|
raise SyntaxError(string + CLI_HELP_MSG) from e
|
||||||
|
|
||||||
@ -265,7 +264,7 @@ def entrypoint(debug=''):
|
|||||||
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")
|
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")
|
||||||
elif mode not in modes:
|
elif mode not in modes:
|
||||||
if mode != 'checks':
|
if mode != 'checks':
|
||||||
raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {modes}.")
|
raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {modes}.\n{CLI_HELP_MSG}")
|
||||||
LOGGER.warning("WARNING ⚠️ 'yolo mode=checks' is deprecated. Use 'yolo checks' instead.")
|
LOGGER.warning("WARNING ⚠️ 'yolo mode=checks' is deprecated. Use 'yolo checks' instead.")
|
||||||
checks.check_yolo()
|
checks.check_yolo()
|
||||||
return
|
return
|
||||||
|
@ -682,7 +682,8 @@ def v8_transforms(dataset, imgsz, hyp):
|
|||||||
# Classification augmentations -----------------------------------------------------------------------------------------
|
# Classification augmentations -----------------------------------------------------------------------------------------
|
||||||
def classify_transforms(size=224):
|
def classify_transforms(size=224):
|
||||||
# Transforms to apply if albumentations not installed
|
# Transforms to apply if albumentations not installed
|
||||||
assert isinstance(size, int), f"ERROR: classify_transforms size {size} must be integer, not (list, tuple)"
|
if not isinstance(size, int):
|
||||||
|
raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)")
|
||||||
# T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
|
# T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
|
||||||
return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
|
return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
|
||||||
|
|
||||||
|
@ -48,7 +48,6 @@ TensorFlow.js:
|
|||||||
$ ln -s ../../yolov5/yolov8n_web_model public/yolov8n_web_model
|
$ ln -s ../../yolov5/yolov8n_web_model public/yolov8n_web_model
|
||||||
$ npm start
|
$ npm start
|
||||||
"""
|
"""
|
||||||
import contextlib
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
@ -74,7 +73,7 @@ from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, __version__, callbacks,
|
|||||||
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
|
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
|
||||||
from ultralytics.yolo.utils.files import file_size
|
from ultralytics.yolo.utils.files import file_size
|
||||||
from ultralytics.yolo.utils.ops import Profile
|
from ultralytics.yolo.utils.ops import Profile
|
||||||
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
|
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode, get_latest_opset
|
||||||
|
|
||||||
MACOS = platform.system() == 'Darwin' # macOS environment
|
MACOS = platform.system() == 'Darwin' # macOS environment
|
||||||
|
|
||||||
@ -97,6 +96,10 @@ def export_formats():
|
|||||||
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
|
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
|
||||||
|
|
||||||
|
|
||||||
|
EXPORT_FORMATS_LIST = list(export_formats()['Argument'][1:])
|
||||||
|
EXPORT_FORMATS_TABLE = str(export_formats())
|
||||||
|
|
||||||
|
|
||||||
def try_export(inner_func):
|
def try_export(inner_func):
|
||||||
# YOLOv8 export decorator, i..e @try_export
|
# YOLOv8 export decorator, i..e @try_export
|
||||||
inner_args = get_default_args(inner_func)
|
inner_args = get_default_args(inner_func)
|
||||||
@ -244,7 +247,7 @@ class Exporter:
|
|||||||
agnostic_nms=self.args.agnostic_nms)
|
agnostic_nms=self.args.agnostic_nms)
|
||||||
if edgetpu:
|
if edgetpu:
|
||||||
f[8], _ = self._export_edgetpu()
|
f[8], _ = self._export_edgetpu()
|
||||||
self._add_tflite_metadata(f[8] or f[7], num_outputs=len(self.output_shape))
|
self._add_tflite_metadata(f[8] or f[7])
|
||||||
if tfjs:
|
if tfjs:
|
||||||
f[9], _ = self._export_tfjs()
|
f[9], _ = self._export_tfjs()
|
||||||
if paddle: # PaddlePaddle
|
if paddle: # PaddlePaddle
|
||||||
@ -253,11 +256,11 @@ class Exporter:
|
|||||||
# Finish
|
# Finish
|
||||||
f = [str(x) for x in f if x] # filter out '' and None
|
f = [str(x) for x in f if x] # filter out '' and None
|
||||||
if any(f):
|
if any(f):
|
||||||
s = "-WARNING ⚠️ not yet supported for YOLOv8 exported models"
|
f = str(Path(f[-1]))
|
||||||
LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
|
LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
|
||||||
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
|
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
|
||||||
f"\nPredict: yolo task={model.task} mode=predict model={f[-1]} {s}"
|
f"\nPredict: yolo task={model.task} mode=predict model={f}"
|
||||||
f"\nValidate: yolo task={model.task} mode=val model={f[-1]} {s}"
|
f"\nValidate: yolo task={model.task} mode=val model={f}"
|
||||||
f"\nVisualize: https://netron.app")
|
f"\nVisualize: https://netron.app")
|
||||||
|
|
||||||
self.run_callbacks("on_export_end")
|
self.run_callbacks("on_export_end")
|
||||||
@ -304,7 +307,7 @@ class Exporter:
|
|||||||
self.im.cpu() if dynamic else self.im,
|
self.im.cpu() if dynamic else self.im,
|
||||||
f,
|
f,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
opset_version=self.args.opset,
|
opset_version=self.args.opset or get_latest_opset(),
|
||||||
do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
|
do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
|
||||||
input_names=['images'],
|
input_names=['images'],
|
||||||
output_names=output_names,
|
output_names=output_names,
|
||||||
@ -507,6 +510,10 @@ class Exporter:
|
|||||||
# Export to TF SavedModel
|
# Export to TF SavedModel
|
||||||
subprocess.run(f'onnx2tf -i {onnx} --output_signaturedefs -o {f}', shell=True)
|
subprocess.run(f'onnx2tf -i {onnx} --output_signaturedefs -o {f}', shell=True)
|
||||||
|
|
||||||
|
# Add TFLite metadata
|
||||||
|
for tflite_file in Path(f).rglob('*.tflite'):
|
||||||
|
self._add_tflite_metadata(tflite_file)
|
||||||
|
|
||||||
# Load saved_model
|
# Load saved_model
|
||||||
keras_model = tf.saved_model.load(f, tags=None, options=None)
|
keras_model = tf.saved_model.load(f, tags=None, options=None)
|
||||||
|
|
||||||
@ -661,17 +668,20 @@ class Exporter:
|
|||||||
r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
|
r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
|
||||||
r'"Identity.?.?": {"name": "Identity.?.?"}, '
|
r'"Identity.?.?": {"name": "Identity.?.?"}, '
|
||||||
r'"Identity.?.?": {"name": "Identity.?.?"}, '
|
r'"Identity.?.?": {"name": "Identity.?.?"}, '
|
||||||
r'"Identity.?.?": {"name": "Identity.?.?"}}}', r'{"outputs": {"Identity": {"name": "Identity"}, '
|
r'"Identity.?.?": {"name": "Identity.?.?"}}}',
|
||||||
|
r'{"outputs": {"Identity": {"name": "Identity"}, '
|
||||||
r'"Identity_1": {"name": "Identity_1"}, '
|
r'"Identity_1": {"name": "Identity_1"}, '
|
||||||
r'"Identity_2": {"name": "Identity_2"}, '
|
r'"Identity_2": {"name": "Identity_2"}, '
|
||||||
r'"Identity_3": {"name": "Identity_3"}}}', f_json.read_text())
|
r'"Identity_3": {"name": "Identity_3"}}}',
|
||||||
|
f_json.read_text(),
|
||||||
|
)
|
||||||
j.write(subst)
|
j.write(subst)
|
||||||
return f, None
|
return f, None
|
||||||
|
|
||||||
def _add_tflite_metadata(self, file, num_outputs):
|
def _add_tflite_metadata(self, file):
|
||||||
# Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata
|
# Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata
|
||||||
with contextlib.suppress(ImportError):
|
check_requirements('tflite_support')
|
||||||
# check_requirements('tflite_support')
|
|
||||||
from tflite_support import flatbuffers # noqa
|
from tflite_support import flatbuffers # noqa
|
||||||
from tflite_support import metadata as _metadata # noqa
|
from tflite_support import metadata as _metadata # noqa
|
||||||
from tflite_support import metadata_schema_py_generated as _metadata_fb # noqa
|
from tflite_support import metadata_schema_py_generated as _metadata_fb # noqa
|
||||||
@ -687,7 +697,7 @@ class Exporter:
|
|||||||
|
|
||||||
subgraph = _metadata_fb.SubGraphMetadataT()
|
subgraph = _metadata_fb.SubGraphMetadataT()
|
||||||
subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()]
|
subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()]
|
||||||
subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * num_outputs
|
subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * len(self.output_shape)
|
||||||
model_meta.subgraphMetadata = [subgraph]
|
model_meta.subgraphMetadata = [subgraph]
|
||||||
|
|
||||||
b = flatbuffers.Builder(0)
|
b = flatbuffers.Builder(0)
|
||||||
|
@ -6,11 +6,11 @@ from typing import List
|
|||||||
|
|
||||||
from ultralytics import yolo # noqa
|
from ultralytics import yolo # noqa
|
||||||
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight,
|
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight,
|
||||||
guess_model_task)
|
guess_model_task, nn)
|
||||||
from ultralytics.yolo.cfg import get_cfg
|
from ultralytics.yolo.cfg import get_cfg
|
||||||
from ultralytics.yolo.engine.exporter import Exporter
|
from ultralytics.yolo.engine.exporter import Exporter
|
||||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, callbacks, yaml_load
|
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, callbacks, yaml_load
|
||||||
from ultralytics.yolo.utils.checks import check_imgsz, check_yaml
|
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_yaml
|
||||||
from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS
|
from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS
|
||||||
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
|
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
|
||||||
|
|
||||||
@ -55,19 +55,16 @@ class YOLO:
|
|||||||
self.cfg = None # if loaded from *.yaml
|
self.cfg = None # if loaded from *.yaml
|
||||||
self.ckpt_path = None
|
self.ckpt_path = None
|
||||||
self.overrides = {} # overrides for trainer object
|
self.overrides = {} # overrides for trainer object
|
||||||
|
self.metrics_data = None
|
||||||
|
|
||||||
# Load or create new YOLO model
|
# Load or create new YOLO model
|
||||||
suffix = Path(model).suffix
|
suffix = Path(model).suffix
|
||||||
if not suffix and Path(model).stem in GITHUB_ASSET_STEMS:
|
if not suffix and Path(model).stem in GITHUB_ASSET_STEMS:
|
||||||
model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt
|
model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt
|
||||||
try:
|
|
||||||
if suffix == '.yaml':
|
if suffix == '.yaml':
|
||||||
self._new(model)
|
self._new(model)
|
||||||
else:
|
else:
|
||||||
self._load(model)
|
self._load(model)
|
||||||
except Exception as e:
|
|
||||||
raise NotImplementedError(f"Unable to load model='{model}'. "
|
|
||||||
f"As an example try model='yolov8n.pt' or model='yolov8n.yaml'") from e
|
|
||||||
|
|
||||||
def __call__(self, source=None, stream=False, **kwargs):
|
def __call__(self, source=None, stream=False, **kwargs):
|
||||||
return self.predict(source, stream, **kwargs)
|
return self.predict(source, stream, **kwargs)
|
||||||
@ -100,15 +97,27 @@ class YOLO:
|
|||||||
self.overrides = self.model.args
|
self.overrides = self.model.args
|
||||||
self._reset_ckpt_args(self.overrides)
|
self._reset_ckpt_args(self.overrides)
|
||||||
else:
|
else:
|
||||||
|
check_file(weights)
|
||||||
self.model, self.ckpt = weights, None
|
self.model, self.ckpt = weights, None
|
||||||
self.task = guess_model_task(weights)
|
self.task = guess_model_task(weights)
|
||||||
self.ckpt_path = weights
|
self.ckpt_path = weights
|
||||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task()
|
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task()
|
||||||
|
|
||||||
|
def _check_is_pytorch_model(self):
|
||||||
|
"""
|
||||||
|
Raises TypeError is model is not a PyTorch model
|
||||||
|
"""
|
||||||
|
if not isinstance(self.model, nn.Module):
|
||||||
|
raise TypeError(f"model='{self.model}' must be a PyTorch model, but is a different type. PyTorch models "
|
||||||
|
f"can be used to train, val, predict and export, i.e. "
|
||||||
|
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
|
||||||
|
f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""
|
"""
|
||||||
Resets the model modules.
|
Resets the model modules.
|
||||||
"""
|
"""
|
||||||
|
self._check_is_pytorch_model()
|
||||||
for m in self.model.modules():
|
for m in self.model.modules():
|
||||||
if hasattr(m, 'reset_parameters'):
|
if hasattr(m, 'reset_parameters'):
|
||||||
m.reset_parameters()
|
m.reset_parameters()
|
||||||
@ -122,9 +131,11 @@ class YOLO:
|
|||||||
Args:
|
Args:
|
||||||
verbose (bool): Controls verbosity.
|
verbose (bool): Controls verbosity.
|
||||||
"""
|
"""
|
||||||
|
self._check_is_pytorch_model()
|
||||||
self.model.info(verbose=verbose)
|
self.model.info(verbose=verbose)
|
||||||
|
|
||||||
def fuse(self):
|
def fuse(self):
|
||||||
|
self._check_is_pytorch_model()
|
||||||
self.model.fuse()
|
self.model.fuse()
|
||||||
|
|
||||||
def predict(self, source=None, stream=False, **kwargs):
|
def predict(self, source=None, stream=False, **kwargs):
|
||||||
@ -176,6 +187,8 @@ class YOLO:
|
|||||||
|
|
||||||
validator = self.ValidatorClass(args=args)
|
validator = self.ValidatorClass(args=args)
|
||||||
validator(model=self.model)
|
validator(model=self.model)
|
||||||
|
self.metrics_data = validator.metrics
|
||||||
|
|
||||||
return validator.metrics
|
return validator.metrics
|
||||||
|
|
||||||
@smart_inference_mode()
|
@smart_inference_mode()
|
||||||
@ -186,7 +199,7 @@ class YOLO:
|
|||||||
Args:
|
Args:
|
||||||
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
|
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
|
||||||
"""
|
"""
|
||||||
|
self._check_is_pytorch_model()
|
||||||
overrides = self.overrides.copy()
|
overrides = self.overrides.copy()
|
||||||
overrides.update(kwargs)
|
overrides.update(kwargs)
|
||||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||||
@ -196,7 +209,7 @@ class YOLO:
|
|||||||
if args.batch == DEFAULT_CFG.batch:
|
if args.batch == DEFAULT_CFG.batch:
|
||||||
args.batch = 1 # default to 1 if not modified
|
args.batch = 1 # default to 1 if not modified
|
||||||
exporter = Exporter(overrides=args)
|
exporter = Exporter(overrides=args)
|
||||||
exporter(model=self.model)
|
return exporter(model=self.model)
|
||||||
|
|
||||||
def train(self, **kwargs):
|
def train(self, **kwargs):
|
||||||
"""
|
"""
|
||||||
@ -205,6 +218,7 @@ class YOLO:
|
|||||||
Args:
|
Args:
|
||||||
**kwargs (Any): Any number of arguments representing the training configuration.
|
**kwargs (Any): Any number of arguments representing the training configuration.
|
||||||
"""
|
"""
|
||||||
|
self._check_is_pytorch_model()
|
||||||
overrides = self.overrides.copy()
|
overrides = self.overrides.copy()
|
||||||
overrides.update(kwargs)
|
overrides.update(kwargs)
|
||||||
if kwargs.get("cfg"):
|
if kwargs.get("cfg"):
|
||||||
@ -226,6 +240,7 @@ class YOLO:
|
|||||||
if RANK in {0, -1}:
|
if RANK in {0, -1}:
|
||||||
self.model, _ = attempt_load_one_weight(str(self.trainer.best))
|
self.model, _ = attempt_load_one_weight(str(self.trainer.best))
|
||||||
self.overrides = self.model.args
|
self.overrides = self.model.args
|
||||||
|
self.metrics_data = self.trainer.validator.metrics
|
||||||
|
|
||||||
def to(self, device):
|
def to(self, device):
|
||||||
"""
|
"""
|
||||||
@ -234,15 +249,14 @@ class YOLO:
|
|||||||
Args:
|
Args:
|
||||||
device (str): device
|
device (str): device
|
||||||
"""
|
"""
|
||||||
|
self._check_is_pytorch_model()
|
||||||
self.model.to(device)
|
self.model.to(device)
|
||||||
|
|
||||||
def _assign_ops_from_task(self):
|
def _assign_ops_from_task(self):
|
||||||
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[self.task]
|
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[self.task]
|
||||||
# warning: eval is unsafe. Use with caution
|
|
||||||
trainer_class = eval(train_lit.replace("TYPE", f"{self.type}"))
|
trainer_class = eval(train_lit.replace("TYPE", f"{self.type}"))
|
||||||
validator_class = eval(val_lit.replace("TYPE", f"{self.type}"))
|
validator_class = eval(val_lit.replace("TYPE", f"{self.type}"))
|
||||||
predictor_class = eval(pred_lit.replace("TYPE", f"{self.type}"))
|
predictor_class = eval(pred_lit.replace("TYPE", f"{self.type}"))
|
||||||
|
|
||||||
return model_class, trainer_class, validator_class, predictor_class
|
return model_class, trainer_class, validator_class, predictor_class
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -250,7 +264,7 @@ class YOLO:
|
|||||||
"""
|
"""
|
||||||
Returns class names of the loaded model.
|
Returns class names of the loaded model.
|
||||||
"""
|
"""
|
||||||
return self.model.names
|
return self.model.names if hasattr(self.model, 'names') else None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def transforms(self):
|
def transforms(self):
|
||||||
@ -259,6 +273,16 @@ class YOLO:
|
|||||||
"""
|
"""
|
||||||
return self.model.transforms if hasattr(self.model, 'transforms') else None
|
return self.model.transforms if hasattr(self.model, 'transforms') else None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def metrics(self):
|
||||||
|
"""
|
||||||
|
Returns metrics if computed
|
||||||
|
"""
|
||||||
|
if not self.metrics_data:
|
||||||
|
LOGGER.info("No metrics data found! Run training or validation operation first.")
|
||||||
|
|
||||||
|
return self.metrics_data
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_callback(event: str, func):
|
def add_callback(event: str, func):
|
||||||
"""
|
"""
|
||||||
@ -269,5 +293,5 @@ class YOLO:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _reset_ckpt_args(args):
|
def _reset_ckpt_args(args):
|
||||||
for arg in 'augment', 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', \
|
for arg in 'augment', 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', \
|
||||||
'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots':
|
'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots', 'opset':
|
||||||
args.pop(arg, None)
|
args.pop(arg, None)
|
||||||
|
@ -35,6 +35,7 @@ import torch
|
|||||||
from ultralytics.nn.autobackend import AutoBackend
|
from ultralytics.nn.autobackend import AutoBackend
|
||||||
from ultralytics.yolo.cfg import get_cfg
|
from ultralytics.yolo.cfg import get_cfg
|
||||||
from ultralytics.yolo.data import load_inference_source
|
from ultralytics.yolo.data import load_inference_source
|
||||||
|
from ultralytics.yolo.data.augment import classify_transforms
|
||||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, SETTINGS, callbacks, colorstr, ops
|
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, SETTINGS, callbacks, colorstr, ops
|
||||||
from ultralytics.yolo.utils.checks import check_imgsz, check_imshow
|
from ultralytics.yolo.utils.checks import check_imgsz, check_imshow
|
||||||
from ultralytics.yolo.utils.files import increment_path
|
from ultralytics.yolo.utils.files import increment_path
|
||||||
@ -121,8 +122,12 @@ class BasePredictor:
|
|||||||
|
|
||||||
def setup_source(self, source):
|
def setup_source(self, source):
|
||||||
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
|
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
|
||||||
|
if self.args.task == 'classify':
|
||||||
|
transforms = getattr(self.model.model, 'transforms', classify_transforms(self.imgsz[0]))
|
||||||
|
else: # predict, segment
|
||||||
|
transforms = None
|
||||||
self.dataset = load_inference_source(source=source,
|
self.dataset = load_inference_source(source=source,
|
||||||
transforms=getattr(self.model.model, 'transforms', None),
|
transforms=transforms,
|
||||||
imgsz=self.imgsz,
|
imgsz=self.imgsz,
|
||||||
vid_stride=self.args.vid_stride,
|
vid_stride=self.args.vid_stride,
|
||||||
stride=self.model.stride,
|
stride=self.model.stride,
|
||||||
|
@ -217,19 +217,18 @@ class BaseTrainer:
|
|||||||
|
|
||||||
# Optimizer
|
# Optimizer
|
||||||
self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing
|
self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing
|
||||||
self.args.weight_decay *= self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
|
weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
|
||||||
self.optimizer = self.build_optimizer(model=self.model,
|
self.optimizer = self.build_optimizer(model=self.model,
|
||||||
name=self.args.optimizer,
|
name=self.args.optimizer,
|
||||||
lr=self.args.lr0,
|
lr=self.args.lr0,
|
||||||
momentum=self.args.momentum,
|
momentum=self.args.momentum,
|
||||||
decay=self.args.weight_decay)
|
decay=weight_decay)
|
||||||
# Scheduler
|
# Scheduler
|
||||||
if self.args.cos_lr:
|
if self.args.cos_lr:
|
||||||
self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf']
|
self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf']
|
||||||
else:
|
else:
|
||||||
self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf # linear
|
self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf # linear
|
||||||
self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
|
self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
|
||||||
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
|
|
||||||
self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
|
self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
|
||||||
|
|
||||||
# dataloaders
|
# dataloaders
|
||||||
@ -242,6 +241,7 @@ class BaseTrainer:
|
|||||||
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
|
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) # TODO: init metrics for plot_results()?
|
||||||
self.ema = ModelEMA(self.model)
|
self.ema = ModelEMA(self.model)
|
||||||
self.resume_training(ckpt)
|
self.resume_training(ckpt)
|
||||||
|
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
|
||||||
self.run_callbacks("on_pretrain_routine_end")
|
self.run_callbacks("on_pretrain_routine_end")
|
||||||
|
|
||||||
def _do_train(self, rank=-1, world_size=1):
|
def _do_train(self, rank=-1, world_size=1):
|
||||||
@ -555,6 +555,12 @@ class BaseTrainer:
|
|||||||
self.epochs += ckpt['epoch'] # finetune additional epochs
|
self.epochs += ckpt['epoch'] # finetune additional epochs
|
||||||
self.best_fitness = best_fitness
|
self.best_fitness = best_fitness
|
||||||
self.start_epoch = start_epoch
|
self.start_epoch = start_epoch
|
||||||
|
if start_epoch > (self.epochs - self.args.close_mosaic):
|
||||||
|
self.console.info("Closing dataloader mosaic")
|
||||||
|
if hasattr(self.train_loader.dataset, 'mosaic'):
|
||||||
|
self.train_loader.dataset.mosaic = False
|
||||||
|
if hasattr(self.train_loader.dataset, 'close_mosaic'):
|
||||||
|
self.train_loader.dataset.close_mosaic(hyp=self.args)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
|
def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
|
||||||
|
@ -234,17 +234,17 @@ def check_yolov5u_filename(file: str):
|
|||||||
return file
|
return file
|
||||||
|
|
||||||
|
|
||||||
def check_file(file, suffix=''):
|
def check_file(file, suffix='', download=True):
|
||||||
# Search/download file (if necessary) and return path
|
# Search/download file (if necessary) and return path
|
||||||
check_suffix(file, suffix) # optional
|
check_suffix(file, suffix) # optional
|
||||||
file = str(file) # convert to string
|
file = str(file) # convert to string
|
||||||
file = check_yolov5u_filename(file) # yolov5n -> yolov5nu
|
file = check_yolov5u_filename(file) # yolov5n -> yolov5nu
|
||||||
if not file or ('://' not in file and Path(file).is_file()): # exists ('://' check required in Windows Python<3.10)
|
if not file or ('://' not in file and Path(file).exists()): # exists ('://' check required in Windows Python<3.10)
|
||||||
return file
|
return file
|
||||||
elif file.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://')): # download
|
elif download and file.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://')): # download
|
||||||
url = file # warning: Pathlib turns :// -> :/
|
url = file # warning: Pathlib turns :// -> :/
|
||||||
file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
|
file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
|
||||||
if Path(file).is_file():
|
if Path(file).exists():
|
||||||
LOGGER.info(f'Found {url} locally at {file}') # file already exists
|
LOGGER.info(f'Found {url} locally at {file}') # file already exists
|
||||||
else:
|
else:
|
||||||
downloads.safe_download(url=url, file=file, unzip=False)
|
downloads.safe_download(url=url, file=file, unzip=False)
|
||||||
|
@ -44,11 +44,17 @@ def generate_ddp_file(trainer):
|
|||||||
|
|
||||||
def generate_ddp_command(world_size, trainer):
|
def generate_ddp_command(world_size, trainer):
|
||||||
import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
|
import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
|
||||||
file = generate_ddp_file(trainer) if sys.argv[0].endswith('yolo') else os.path.abspath(sys.argv[0])
|
|
||||||
|
# Get file and args (do not use sys.argv due to security vulnerability)
|
||||||
|
exclude_args = ['save_dir']
|
||||||
|
args = [f"{k}={v}" for k, v in vars(trainer.args).items() if k not in exclude_args]
|
||||||
|
file = generate_ddp_file(trainer) # if argv[0].endswith('yolo') else os.path.abspath(argv[0])
|
||||||
|
|
||||||
|
# Build command
|
||||||
torch_distributed_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
|
torch_distributed_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
|
||||||
cmd = [
|
cmd = [
|
||||||
sys.executable, "-m", torch_distributed_cmd, "--nproc_per_node", f"{world_size}", "--master_port",
|
sys.executable, "-m", torch_distributed_cmd, "--nproc_per_node", f"{world_size}", "--master_port",
|
||||||
f"{find_free_network_port()}", file] + sys.argv[1:]
|
f"{find_free_network_port()}", file] + args
|
||||||
return cmd, file
|
return cmd, file
|
||||||
|
|
||||||
|
|
||||||
|
@ -242,6 +242,11 @@ def copy_attr(a, b, include=(), exclude=()):
|
|||||||
setattr(a, k, v)
|
setattr(a, k, v)
|
||||||
|
|
||||||
|
|
||||||
|
def get_latest_opset():
|
||||||
|
# Return max supported ONNX opset by this version of torch
|
||||||
|
return max(int(k[14:]) for k in vars(torch.onnx) if 'symbolic_opset' in k) # opset
|
||||||
|
|
||||||
|
|
||||||
def intersect_dicts(da, db, exclude=()):
|
def intersect_dicts(da, db, exclude=()):
|
||||||
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
|
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
|
||||||
return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
|
return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user