mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
ultralytics 8.0.129
add YOLOv8 Tencent NCNN export (#3529)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
a254087fcd
commit
abb0939fe8
@ -1,7 +1,7 @@
|
||||
---
|
||||
comments: true
|
||||
description: 'Export mode: Create a deployment-ready YOLOv8 model by converting it to various formats. Export to ONNX or OpenVINO for up to 3x CPU speedup.'
|
||||
keywords: ultralytics docs, YOLOv8, export YOLOv8, YOLOv8 model deployment, exporting YOLOv8, ONNX, OpenVINO, TensorRT, CoreML, TF SavedModel, PaddlePaddle, TorchScript, ONNX format, OpenVINO format, TensorRT format, CoreML format, TF SavedModel format, PaddlePaddle format
|
||||
keywords: ultralytics docs, YOLOv8, export YOLOv8, YOLOv8 model deployment, exporting YOLOv8, ONNX, OpenVINO, TensorRT, CoreML, TF SavedModel, PaddlePaddle, TorchScript, ONNX format, OpenVINO format, TensorRT format, CoreML format, TF SavedModel format, PaddlePaddle format, Tencent NCNN, NCNN
|
||||
---
|
||||
|
||||
<img width="1024" src="https://github.com/ultralytics/assets/raw/main/yolov8/banner-integrations.png">
|
||||
@ -84,4 +84,5 @@ i.e. `format='onnx'` or `format='engine'`.
|
||||
| [TF Lite](https://www.tensorflow.org/lite) | `tflite` | `yolov8n.tflite` | ✅ | `imgsz`, `half`, `int8` |
|
||||
| [TF Edge TPU](https://coral.ai/docs/edgetpu/models-intro/) | `edgetpu` | `yolov8n_edgetpu.tflite` | ✅ | `imgsz` |
|
||||
| [TF.js](https://www.tensorflow.org/js) | `tfjs` | `yolov8n_web_model/` | ✅ | `imgsz` |
|
||||
| [PaddlePaddle](https://github.com/PaddlePaddle) | `paddle` | `yolov8n_paddle_model/` | ✅ | `imgsz` |
|
||||
| [PaddlePaddle](https://github.com/PaddlePaddle) | `paddle` | `yolov8n_paddle_model/` | ✅ | `imgsz` |
|
||||
| [NCNN](https://github.com/Tencent/ncnn) | `ncnn` | `yolov8n_ncnn_model/` | ✅ | `imgsz` |
|
@ -6,9 +6,7 @@ keywords: Ultralytics, YOLO, YOLOv8, Val, Validation, Hyperparameters, Performan
|
||||
|
||||
<img width="1024" src="https://github.com/ultralytics/assets/raw/main/yolov8/banner-integrations.png">
|
||||
|
||||
**Val mode** is used for validating a YOLOv8 model after it has been trained. In this mode, the model is evaluated on a
|
||||
validation set to measure its accuracy and generalization performance. This mode can be used to tune the hyperparameters
|
||||
of the model to improve its performance.
|
||||
**Val mode** is used for validating a YOLOv8 model after it has been trained. In this mode, the model is evaluated on a validation set to measure its accuracy and generalization performance. This mode can be used to tune the hyperparameters of the model to improve its performance.
|
||||
|
||||
!!! tip "Tip"
|
||||
|
||||
@ -16,8 +14,7 @@ of the model to improve its performance.
|
||||
|
||||
## Usage Examples
|
||||
|
||||
Validate trained YOLOv8n model accuracy on the COCO128 dataset. No argument need to passed as the `model` retains it's
|
||||
training `data` and arguments as model attributes. See Arguments section below for a full list of export arguments.
|
||||
Validate trained YOLOv8n model accuracy on the COCO128 dataset. No argument need to passed as the `model` retains it's training `data` and arguments as model attributes. See Arguments section below for a full list of export arguments.
|
||||
|
||||
!!! example ""
|
||||
|
||||
@ -46,13 +43,7 @@ training `data` and arguments as model attributes. See Arguments section below f
|
||||
|
||||
## Arguments
|
||||
|
||||
Validation settings for YOLO models refer to the various hyperparameters and configurations used to
|
||||
evaluate the model's performance on a validation dataset. These settings can affect the model's performance, speed, and
|
||||
accuracy. Some common YOLO validation settings include the batch size, the frequency with which validation is performed
|
||||
during training, and the metrics used to evaluate the model's performance. Other factors that may affect the validation
|
||||
process include the size and composition of the validation dataset and the specific task the model is being used for. It
|
||||
is important to carefully tune and experiment with these settings to ensure that the model is performing well on the
|
||||
validation dataset and to detect and prevent overfitting.
|
||||
Validation settings for YOLO models refer to the various hyperparameters and configurations used to evaluate the model's performance on a validation dataset. These settings can affect the model's performance, speed, and accuracy. Some common YOLO validation settings include the batch size, the frequency with which validation is performed during training, and the metrics used to evaluate the model's performance. Other factors that may affect the validation process include the size and composition of the validation dataset and the specific task the model is being used for. It is important to carefully tune and experiment with these settings to ensure that the model is performing well on the validation dataset and to detect and prevent overfitting.
|
||||
|
||||
| Key | Value | Description |
|
||||
|---------------|---------|--------------------------------------------------------------------|
|
||||
@ -70,23 +61,4 @@ validation dataset and to detect and prevent overfitting.
|
||||
| `plots` | `False` | show plots during training |
|
||||
| `rect` | `False` | rectangular val with each batch collated for minimum padding |
|
||||
| `split` | `val` | dataset split to use for validation, i.e. 'val', 'test' or 'train' |
|
||||
|
||||
## Export Formats
|
||||
|
||||
Available YOLOv8 export formats are in the table below. You can export to any format using the `format` argument,
|
||||
i.e. `format='onnx'` or `format='engine'`.
|
||||
|
||||
| Format | `format` Argument | Model | Metadata | Arguments |
|
||||
|--------------------------------------------------------------------|-------------------|---------------------------|----------|-----------------------------------------------------|
|
||||
| [PyTorch](https://pytorch.org/) | - | `yolov8n.pt` | ✅ | - |
|
||||
| [TorchScript](https://pytorch.org/docs/stable/jit.html) | `torchscript` | `yolov8n.torchscript` | ✅ | `imgsz`, `optimize` |
|
||||
| [ONNX](https://onnx.ai/) | `onnx` | `yolov8n.onnx` | ✅ | `imgsz`, `half`, `dynamic`, `simplify`, `opset` |
|
||||
| [OpenVINO](https://docs.openvino.ai/latest/index.html) | `openvino` | `yolov8n_openvino_model/` | ✅ | `imgsz`, `half` |
|
||||
| [TensorRT](https://developer.nvidia.com/tensorrt) | `engine` | `yolov8n.engine` | ✅ | `imgsz`, `half`, `dynamic`, `simplify`, `workspace` |
|
||||
| [CoreML](https://github.com/apple/coremltools) | `coreml` | `yolov8n.mlmodel` | ✅ | `imgsz`, `half`, `int8`, `nms` |
|
||||
| [TF SavedModel](https://www.tensorflow.org/guide/saved_model) | `saved_model` | `yolov8n_saved_model/` | ✅ | `imgsz`, `keras` |
|
||||
| [TF GraphDef](https://www.tensorflow.org/api_docs/python/tf/Graph) | `pb` | `yolov8n.pb` | ❌ | `imgsz` |
|
||||
| [TF Lite](https://www.tensorflow.org/lite) | `tflite` | `yolov8n.tflite` | ✅ | `imgsz`, `half`, `int8` |
|
||||
| [TF Edge TPU](https://coral.ai/docs/edgetpu/models-intro/) | `edgetpu` | `yolov8n_edgetpu.tflite` | ✅ | `imgsz` |
|
||||
| [TF.js](https://www.tensorflow.org/js) | `tfjs` | `yolov8n_web_model/` | ✅ | `imgsz` |
|
||||
| [PaddlePaddle](https://github.com/PaddlePaddle) | `paddle` | `yolov8n_paddle_model/` | ✅ | `imgsz` |
|
||||
|
|
@ -176,5 +176,6 @@ i.e. `yolo predict model=yolov8n-cls.onnx`. Usage examples are shown for your mo
|
||||
| [TF Edge TPU](https://coral.ai/docs/edgetpu/models-intro/) | `edgetpu` | `yolov8n-cls_edgetpu.tflite` | ✅ | `imgsz` |
|
||||
| [TF.js](https://www.tensorflow.org/js) | `tfjs` | `yolov8n-cls_web_model/` | ✅ | `imgsz` |
|
||||
| [PaddlePaddle](https://github.com/PaddlePaddle) | `paddle` | `yolov8n-cls_paddle_model/` | ✅ | `imgsz` |
|
||||
| [NCNN](https://github.com/Tencent/ncnn) | `ncnn` | `yolov8n-cls_ncnn_model/` | ✅ | `imgsz` |
|
||||
|
||||
See full `export` details in the [Export](https://docs.ultralytics.com/modes/export/) page.
|
@ -167,5 +167,6 @@ Available YOLOv8 export formats are in the table below. You can predict or valid
|
||||
| [TF Edge TPU](https://coral.ai/docs/edgetpu/models-intro/) | `edgetpu` | `yolov8n_edgetpu.tflite` | ✅ | `imgsz` |
|
||||
| [TF.js](https://www.tensorflow.org/js) | `tfjs` | `yolov8n_web_model/` | ✅ | `imgsz` |
|
||||
| [PaddlePaddle](https://github.com/PaddlePaddle) | `paddle` | `yolov8n_paddle_model/` | ✅ | `imgsz` |
|
||||
| [NCNN](https://github.com/Tencent/ncnn) | `ncnn` | `yolov8n_ncnn_model/` | ✅ | `imgsz` |
|
||||
|
||||
See full `export` details in the [Export](https://docs.ultralytics.com/modes/export/) page.
|
@ -181,5 +181,6 @@ i.e. `yolo predict model=yolov8n-pose.onnx`. Usage examples are shown for your m
|
||||
| [TF Edge TPU](https://coral.ai/docs/edgetpu/models-intro/) | `edgetpu` | `yolov8n-pose_edgetpu.tflite` | ✅ | `imgsz` |
|
||||
| [TF.js](https://www.tensorflow.org/js) | `tfjs` | `yolov8n-pose_web_model/` | ✅ | `imgsz` |
|
||||
| [PaddlePaddle](https://github.com/PaddlePaddle) | `paddle` | `yolov8n-pose_paddle_model/` | ✅ | `imgsz` |
|
||||
| [NCNN](https://github.com/Tencent/ncnn) | `ncnn` | `yolov8n-pose_ncnn_model/` | ✅ | `imgsz` |
|
||||
|
||||
See full `export` details in the [Export](https://docs.ultralytics.com/modes/export/) page.
|
@ -181,5 +181,6 @@ i.e. `yolo predict model=yolov8n-seg.onnx`. Usage examples are shown for your mo
|
||||
| [TF Edge TPU](https://coral.ai/docs/edgetpu/models-intro/) | `edgetpu` | `yolov8n-seg_edgetpu.tflite` | ✅ | `imgsz` |
|
||||
| [TF.js](https://www.tensorflow.org/js) | `tfjs` | `yolov8n-seg_web_model/` | ✅ | `imgsz` |
|
||||
| [PaddlePaddle](https://github.com/PaddlePaddle) | `paddle` | `yolov8n-seg_paddle_model/` | ✅ | `imgsz` |
|
||||
| [NCNN](https://github.com/Tencent/ncnn) | `ncnn` | `yolov8n-seg_ncnn_model/` | ✅ | `imgsz` |
|
||||
|
||||
See full `export` details in the [Export](https://docs.ultralytics.com/modes/export/) page.
|
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = '8.0.128'
|
||||
__version__ = '8.0.129'
|
||||
|
||||
from ultralytics.hub import start
|
||||
from ultralytics.vit.rtdetr import RTDETR
|
||||
|
@ -79,7 +79,8 @@ class AutoBackend(nn.Module):
|
||||
super().__init__()
|
||||
w = str(weights[0] if isinstance(weights, list) else weights)
|
||||
nn_module = isinstance(weights, torch.nn.Module)
|
||||
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, triton = self._model_type(w)
|
||||
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn, triton = \
|
||||
self._model_type(w)
|
||||
fp16 &= pt or jit or onnx or engine or nn_module or triton # FP16
|
||||
nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH)
|
||||
stride = 32 # default stride
|
||||
@ -237,7 +238,7 @@ class AutoBackend(nn.Module):
|
||||
meta_file = model.namelist()[0]
|
||||
metadata = ast.literal_eval(model.read(meta_file).decode('utf-8'))
|
||||
elif tfjs: # TF.js
|
||||
raise NotImplementedError('YOLOv8 TF.js inference is not supported')
|
||||
raise NotImplementedError('YOLOv8 TF.js inference is not currently supported.')
|
||||
elif paddle: # PaddlePaddle
|
||||
LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
|
||||
check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
|
||||
@ -252,6 +253,8 @@ class AutoBackend(nn.Module):
|
||||
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
|
||||
output_names = predictor.get_output_names()
|
||||
metadata = w.parents[1] / 'metadata.yaml'
|
||||
elif ncnn: # PaddlePaddle
|
||||
raise NotImplementedError('YOLOv8 NCNN inference is not currently supported.')
|
||||
elif triton: # NVIDIA Triton Inference Server
|
||||
LOGGER.info('Triton Inference Server not supported...')
|
||||
'''
|
||||
|
@ -16,6 +16,7 @@ TensorFlow Lite | `tflite` | yolov8n.tflite
|
||||
TensorFlow Edge TPU | `edgetpu` | yolov8n_edgetpu.tflite
|
||||
TensorFlow.js | `tfjs` | yolov8n_web_model/
|
||||
PaddlePaddle | `paddle` | yolov8n_paddle_model/
|
||||
NCNN | `ncnn` | yolov8n_ncnn_model/
|
||||
|
||||
Requirements:
|
||||
$ pip install ultralytics[export]
|
||||
@ -50,6 +51,7 @@ TensorFlow.js:
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
import warnings
|
||||
@ -62,9 +64,10 @@ from ultralytics.nn.autobackend import check_class_names
|
||||
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
|
||||
from ultralytics.nn.tasks import DetectionModel, SegmentationModel
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, LINUX, LOGGER, MACOS, __version__, callbacks, colorstr,
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, LINUX, LOGGER, MACOS, ROOT, __version__, callbacks, colorstr,
|
||||
get_default_args, yaml_save)
|
||||
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version
|
||||
from ultralytics.yolo.utils.downloads import attempt_download_asset, get_github_assets
|
||||
from ultralytics.yolo.utils.files import file_size
|
||||
from ultralytics.yolo.utils.ops import Profile
|
||||
from ultralytics.yolo.utils.torch_utils import get_latest_opset, select_device, smart_inference_mode
|
||||
@ -87,7 +90,8 @@ def export_formats():
|
||||
['TensorFlow Lite', 'tflite', '.tflite', True, False],
|
||||
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', True, False],
|
||||
['TensorFlow.js', 'tfjs', '_web_model', True, False],
|
||||
['PaddlePaddle', 'paddle', '_paddle_model', True, True], ]
|
||||
['PaddlePaddle', 'paddle', '_paddle_model', True, True],
|
||||
['NCNN', 'ncnn', '_ncnn_model', True, True], ]
|
||||
return pandas.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
|
||||
|
||||
|
||||
@ -153,7 +157,7 @@ class Exporter:
|
||||
flags = [x == format for x in fmts]
|
||||
if sum(flags) != 1:
|
||||
raise ValueError(f"Invalid export format='{format}'. Valid formats are {fmts}")
|
||||
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = flags # export booleans
|
||||
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn = flags # export booleans
|
||||
|
||||
# Load PyTorch model
|
||||
self.device = select_device('cpu' if self.args.device is None else self.args.device)
|
||||
@ -231,7 +235,7 @@ class Exporter:
|
||||
|
||||
# Exports
|
||||
f = [''] * len(fmts) # exported filenames
|
||||
if jit: # TorchScript
|
||||
if jit or ncnn: # TorchScript
|
||||
f[0], _ = self.export_torchscript()
|
||||
if engine: # TensorRT required before ONNX
|
||||
f[1], _ = self.export_engine()
|
||||
@ -254,6 +258,8 @@ class Exporter:
|
||||
f[9], _ = self.export_tfjs()
|
||||
if paddle: # PaddlePaddle
|
||||
f[10], _ = self.export_paddle()
|
||||
if ncnn: # NCNN
|
||||
f[11], _ = self.export_ncnn()
|
||||
|
||||
# Finish
|
||||
f = [str(x) for x in f if x] # filter out '' and None
|
||||
@ -394,6 +400,57 @@ class Exporter:
|
||||
yaml_save(Path(f) / 'metadata.yaml', self.metadata) # add metadata.yaml
|
||||
return f, None
|
||||
|
||||
@try_export
|
||||
def export_ncnn(self, prefix=colorstr('NCNN:')):
|
||||
"""
|
||||
YOLOv8 NCNN export using PNNX https://github.com/pnnx/pnnx.
|
||||
"""
|
||||
check_requirements('ncnn') # requires NCNN
|
||||
import ncnn # noqa
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with NCNN {ncnn.__version__}...')
|
||||
f = Path(str(self.file).replace(self.file.suffix, f'_ncnn_model{os.sep}'))
|
||||
f_ts = str(self.file.with_suffix('.torchscript'))
|
||||
|
||||
if Path('./pnnx').is_file():
|
||||
pnnx = './pnnx'
|
||||
elif (ROOT / 'pnnx').is_file():
|
||||
pnnx = ROOT / 'pnnx'
|
||||
else:
|
||||
LOGGER.warning(
|
||||
f'{prefix} WARNING ⚠️ PNNX not found. Attempting to download binary file from '
|
||||
'https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory '
|
||||
f'or in {ROOT}. See PNNX repo for full installation instructions.')
|
||||
_, assets = get_github_assets(repo='pnnx/pnnx')
|
||||
asset = [x for x in assets if ('macos' if MACOS else 'ubuntu' if LINUX else 'windows') in x][0]
|
||||
attempt_download_asset(asset, repo='pnnx/pnnx', release='latest')
|
||||
unzip_dir = Path(asset).with_suffix('')
|
||||
pnnx = ROOT / 'pnnx' # new location
|
||||
(unzip_dir / 'pnnx').rename(pnnx) # move binary to ROOT
|
||||
shutil.rmtree(unzip_dir) # delete unzip dir
|
||||
Path(asset).unlink() # delete zip
|
||||
pnnx.chmod(0o777) # set read, write, and execute permissions for everyone
|
||||
|
||||
cmd = [
|
||||
str(pnnx),
|
||||
f_ts,
|
||||
f'pnnxparam={f / "model.pnnx.param"}',
|
||||
f'pnnxbin={f / "model.pnnx.bin"}',
|
||||
f'pnnxpy={f / "model_pnnx.py"}',
|
||||
f'pnnxonnx={f / "model.pnnx.onnx"}',
|
||||
f'ncnnparam={f / "model.ncnn.param"}',
|
||||
f'ncnnbin={f / "model.ncnn.bin"}',
|
||||
f'ncnnpy={f / "model_ncnn.py"}',
|
||||
f'fp16={int(self.args.half)}',
|
||||
f'device={self.device.type}',
|
||||
f'inputshape="{[self.args.batch, 3, *self.imgsz]}"', ]
|
||||
f.mkdir(exist_ok=True) # make ncnn_model directory
|
||||
LOGGER.info(f"{prefix} running '{' '.join(cmd)}'")
|
||||
subprocess.run(cmd, check=True)
|
||||
|
||||
yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml
|
||||
return str(f), None
|
||||
|
||||
@try_export
|
||||
def export_coreml(self, prefix=colorstr('CoreML:')):
|
||||
"""YOLOv8 CoreML export."""
|
||||
|
@ -21,6 +21,7 @@ TensorFlow Lite | `tflite` | yolov8n.tflite
|
||||
TensorFlow Edge TPU | `edgetpu` | yolov8n_edgetpu.tflite
|
||||
TensorFlow.js | `tfjs` | yolov8n_web_model/
|
||||
PaddlePaddle | `paddle` | yolov8n_paddle_model/
|
||||
NCNN | `ncnn` | yolov8n_ncnn_model/
|
||||
"""
|
||||
|
||||
import glob
|
||||
@ -98,7 +99,7 @@ def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
|
||||
|
||||
# Predict
|
||||
assert model.task != 'pose' or i != 7, 'GraphDef Pose inference is not supported'
|
||||
assert i not in (9, 10), 'inference not supported' # Edge TPU and TF.js are unsupported
|
||||
assert i not in (9, 10, 12), 'inference not supported' # Edge TPU, TF.js and NCNN are unsupported
|
||||
assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13' # CoreML
|
||||
if not (ROOT / 'assets/bus.jpg').exists():
|
||||
download(url='https://ultralytics.com/images/bus.jpg', dir=ROOT / 'assets')
|
||||
|
@ -8,6 +8,7 @@ import platform
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@ -235,13 +236,16 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
|
||||
|
||||
if s:
|
||||
if install and AUTOINSTALL: # check environment variable
|
||||
LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
|
||||
pkgs = file or requirements # missing packages
|
||||
LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
|
||||
try:
|
||||
t = time.time()
|
||||
assert is_online(), 'AutoUpdate skipped (offline)'
|
||||
LOGGER.info(subprocess.check_output(f'pip install --no-cache {s} {cmds}', shell=True).decode())
|
||||
s = f"{prefix} {n} package{'s' * (n > 1)} updated per {file or requirements}\n" \
|
||||
f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
|
||||
LOGGER.info(s)
|
||||
dt = time.time() - t
|
||||
LOGGER.info(
|
||||
f"{prefix} AutoUpdate success ✅ {dt:.1f}s, installed {n} package{'s' * (n > 1)}: {pkgs}\n"
|
||||
f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n")
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'{prefix} ❌ {e}')
|
||||
return False
|
||||
|
@ -189,7 +189,7 @@ def safe_download(url,
|
||||
|
||||
if unzip and f.exists() and f.suffix in ('', '.zip', '.tar', '.gz'):
|
||||
unzip_dir = dir or f.parent # unzip to dir if provided else unzip in place
|
||||
LOGGER.info(f'Unzipping {f} to {unzip_dir}...')
|
||||
LOGGER.info(f'Unzipping {f} to {unzip_dir.absolute()}...')
|
||||
if is_zipfile(f):
|
||||
unzip_dir = unzip_file(file=f, path=unzip_dir) # unzip
|
||||
elif f.suffix == '.tar':
|
||||
@ -201,17 +201,18 @@ def safe_download(url,
|
||||
return unzip_dir
|
||||
|
||||
|
||||
def get_github_assets(repo='ultralytics/assets', version='latest'):
|
||||
"""Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov8s.pt', ...])."""
|
||||
if version != 'latest':
|
||||
version = f'tags/{version}' # i.e. tags/v6.2
|
||||
response = requests.get(f'https://api.github.com/repos/{repo}/releases/{version}').json() # github api
|
||||
return response['tag_name'], [x['name'] for x in response['assets']] # tag, assets
|
||||
|
||||
|
||||
def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
|
||||
"""Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc."""
|
||||
from ultralytics.yolo.utils import SETTINGS # scoped for circular import
|
||||
|
||||
def github_assets(repository, version='latest'):
|
||||
"""Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov8s.pt', ...])."""
|
||||
if version != 'latest':
|
||||
version = f'tags/{version}' # i.e. tags/v6.2
|
||||
response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json() # github api
|
||||
return response['tag_name'], [x['name'] for x in response['assets']] # tag, assets
|
||||
|
||||
# YOLOv3/5u updates
|
||||
file = str(file)
|
||||
file = checks.check_yolov5u_filename(file)
|
||||
@ -235,10 +236,10 @@ def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
|
||||
# GitHub assets
|
||||
assets = GITHUB_ASSET_NAMES
|
||||
try:
|
||||
tag, assets = github_assets(repo, release)
|
||||
tag, assets = get_github_assets(repo, release)
|
||||
except Exception:
|
||||
try:
|
||||
tag, assets = github_assets(repo) # latest release
|
||||
tag, assets = get_github_assets(repo) # latest release
|
||||
except Exception:
|
||||
try:
|
||||
tag = subprocess.check_output(['git', 'tag']).decode().split()[-1]
|
||||
|
Loading…
x
Reference in New Issue
Block a user