mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Create Exporter() Class (#117)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
a9dc1637c2
commit
076d73cfaa
@ -16,9 +16,10 @@ pip install -e .
|
|||||||
### 1. CLI
|
### 1. CLI
|
||||||
To simply use the latest Ultralytics YOLO models
|
To simply use the latest Ultralytics YOLO models
|
||||||
```bash
|
```bash
|
||||||
yolo task=detect mode=train model=yolov8n.yaml ...
|
yolo task=detect mode=train model=yolov8n.yaml args=...
|
||||||
classify predict yolov8n-cls.yaml
|
classify predict yolov8n-cls.yaml args=...
|
||||||
segment val yolov8n-seg.yaml
|
segment val yolov8n-seg.yaml args=...
|
||||||
|
export yolov8n.pt format=onnx
|
||||||
```
|
```
|
||||||
### 2. Python SDK
|
### 2. Python SDK
|
||||||
To use pythonic interface of Ultralytics YOLO model
|
To use pythonic interface of Ultralytics YOLO model
|
||||||
|
@ -11,6 +11,7 @@ from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, Bot
|
|||||||
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
|
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
|
||||||
GhostBottleneck, GhostConv, Segment)
|
GhostBottleneck, GhostConv, Segment)
|
||||||
from ultralytics.yolo.utils import LOGGER, colorstr
|
from ultralytics.yolo.utils import LOGGER, colorstr
|
||||||
|
from ultralytics.yolo.utils.checks import check_yaml
|
||||||
from ultralytics.yolo.utils.files import yaml_load
|
from ultralytics.yolo.utils.files import yaml_load
|
||||||
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_state_dicts,
|
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_state_dicts,
|
||||||
make_divisible, model_info, scale_img, time_sync)
|
make_divisible, model_info, scale_img, time_sync)
|
||||||
@ -80,7 +81,7 @@ class DetectionModel(BaseModel):
|
|||||||
# YOLOv5 detection model
|
# YOLOv5 detection model
|
||||||
def __init__(self, cfg='yolov8n.yaml', ch=3, nc=None, verbose=True): # model, input channels, number of classes
|
def __init__(self, cfg='yolov8n.yaml', ch=3, nc=None, verbose=True): # model, input channels, number of classes
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.yaml = cfg if isinstance(cfg, dict) else yaml_load(cfg) # cfg dict
|
self.yaml = cfg if isinstance(cfg, dict) else yaml_load(check_yaml(cfg)) # cfg dict
|
||||||
|
|
||||||
# Define model
|
# Define model
|
||||||
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
|
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
|
||||||
|
@ -31,7 +31,7 @@ def cli(cfg):
|
|||||||
elif task == "classify":
|
elif task == "classify":
|
||||||
module = yolo.v8.classify
|
module = yolo.v8.classify
|
||||||
elif task == "export":
|
elif task == "export":
|
||||||
func = yolo.trainer.exporter.export_model
|
func = yolo.engine.exporter.export
|
||||||
else:
|
else:
|
||||||
raise SyntaxError("task not recognized. Choices are `'detect', 'segment', 'classify'`")
|
raise SyntaxError("task not recognized. Choices are `'detect', 'segment', 'classify'`")
|
||||||
|
|
||||||
@ -42,7 +42,7 @@ def cli(cfg):
|
|||||||
elif mode == "predict":
|
elif mode == "predict":
|
||||||
func = module.predict
|
func = module.predict
|
||||||
elif mode == "export":
|
elif mode == "export":
|
||||||
func = yolo.trainer.exporter.export_model
|
func = yolo.engine.exporter.export
|
||||||
else:
|
else:
|
||||||
raise SyntaxError("mode not recognized. Choices are `'train', 'val', 'predict', 'export'`")
|
raise SyntaxError("mode not recognized. Choices are `'train', 'val', 'predict', 'export'`")
|
||||||
func(cfg)
|
func(cfg)
|
||||||
|
@ -29,12 +29,12 @@ image_weights: False # use weighted image selection for training
|
|||||||
rect: False # support rectangular training
|
rect: False # support rectangular training
|
||||||
cos_lr: False # use cosine LR scheduler
|
cos_lr: False # use cosine LR scheduler
|
||||||
close_mosaic: 10 # disable mosaic for final 10 epochs
|
close_mosaic: 10 # disable mosaic for final 10 epochs
|
||||||
|
resume: False
|
||||||
# Segmentation
|
# Segmentation
|
||||||
overlap_mask: True # masks overlap
|
overlap_mask: True # masks overlap
|
||||||
mask_ratio: 4 # mask downsample ratio
|
mask_ratio: 4 # mask downsample ratio
|
||||||
# Classification
|
# Classification
|
||||||
dropout: False # use dropout
|
dropout: False # use dropout
|
||||||
resume: False
|
|
||||||
|
|
||||||
|
|
||||||
# Val/Test settings ----------------------------------------------------------------------------------------------------
|
# Val/Test settings ----------------------------------------------------------------------------------------------------
|
||||||
@ -65,6 +65,7 @@ agnostic_nms: False # class-agnostic NMS
|
|||||||
retina_masks: False
|
retina_masks: False
|
||||||
|
|
||||||
# Export settings ------------------------------------------------------------------------------------------------------
|
# Export settings ------------------------------------------------------------------------------------------------------
|
||||||
|
format: torchscript
|
||||||
keras: False # use Keras
|
keras: False # use Keras
|
||||||
optimize: False # TorchScript: optimize for mobile
|
optimize: False # TorchScript: optimize for mobile
|
||||||
int8: False # CoreML/TF INT8 quantization
|
int8: False # CoreML/TF INT8 quantization
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -5,7 +5,7 @@ import torch
|
|||||||
from ultralytics import yolo # noqa required for python usage
|
from ultralytics import yolo # noqa required for python usage
|
||||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_weights
|
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_weights
|
||||||
from ultralytics.yolo.configs import get_config
|
from ultralytics.yolo.configs import get_config
|
||||||
from ultralytics.yolo.engine.exporter import export_model
|
from ultralytics.yolo.engine.exporter import Exporter
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, HELP_MSG, LOGGER
|
from ultralytics.yolo.utils import DEFAULT_CONFIG, HELP_MSG, LOGGER
|
||||||
from ultralytics.yolo.utils.checks import check_yaml
|
from ultralytics.yolo.utils.checks import check_yaml
|
||||||
from ultralytics.yolo.utils.files import yaml_load
|
from ultralytics.yolo.utils.files import yaml_load
|
||||||
@ -164,7 +164,7 @@ class YOLO:
|
|||||||
validator(model=self.model)
|
validator(model=self.model)
|
||||||
|
|
||||||
@smart_inference_mode()
|
@smart_inference_mode()
|
||||||
def export(self, format='', save_dir='', **kwargs):
|
def export(self, **kwargs):
|
||||||
"""
|
"""
|
||||||
Export model.
|
Export model.
|
||||||
|
|
||||||
@ -177,36 +177,9 @@ class YOLO:
|
|||||||
overrides.update(kwargs)
|
overrides.update(kwargs)
|
||||||
args = get_config(config=DEFAULT_CONFIG, overrides=overrides)
|
args = get_config(config=DEFAULT_CONFIG, overrides=overrides)
|
||||||
args.task = self.task
|
args.task = self.task
|
||||||
args.format = format
|
|
||||||
|
|
||||||
file = self.ckpt or Path(Path(self.cfg).name)
|
exporter = Exporter(overrides=overrides)
|
||||||
if save_dir:
|
exporter(model=self.model)
|
||||||
file = Path(save_dir) / file.name
|
|
||||||
file.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
export_model(
|
|
||||||
model=self.model,
|
|
||||||
file=file,
|
|
||||||
data=args.data, # 'dataset.yaml path'
|
|
||||||
imgsz=args.imgsz or (640, 640), # image (height, width)
|
|
||||||
batch_size=1, # batch size
|
|
||||||
device=args.device, # cuda device, i.e. 0 or 0,1,2,3 or cpu
|
|
||||||
format=args.format, # include formats
|
|
||||||
half=args.half or False, # FP16 half-precision export
|
|
||||||
keras=args.keras or False, # use Keras
|
|
||||||
optimize=args.optimize or False, # TorchScript: optimize for mobile
|
|
||||||
int8=args.int8 or False, # CoreML/TF INT8 quantization
|
|
||||||
dynamic=args.dynamic or False, # ONNX/TF/TensorRT: dynamic axes
|
|
||||||
opset=args.opset or 17, # ONNX: opset version
|
|
||||||
verbose=False, # TensorRT: verbose log
|
|
||||||
workspace=args.workspace or 4, # TensorRT: workspace size (GB)
|
|
||||||
nms=False, # TF: add NMS to model
|
|
||||||
agnostic_nms=False, # TF: add agnostic NMS to model
|
|
||||||
topk_per_class=100, # TF.js NMS: topk per class to keep
|
|
||||||
topk_all=100, # TF.js NMS: topk for all classes to keep
|
|
||||||
iou_thres=0.45, # TF.js NMS: IoU threshold
|
|
||||||
conf_thres=0.25, # TF.js NMS: confidence threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
def train(self, **kwargs):
|
def train(self, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -16,14 +16,14 @@ Usage - formats:
|
|||||||
$ yolo task=... mode=predict --weights yolov8n.pt # PyTorch
|
$ yolo task=... mode=predict --weights yolov8n.pt # PyTorch
|
||||||
yolov8n.torchscript # TorchScript
|
yolov8n.torchscript # TorchScript
|
||||||
yolov8n.onnx # ONNX Runtime or OpenCV DNN with --dnn
|
yolov8n.onnx # ONNX Runtime or OpenCV DNN with --dnn
|
||||||
yolov5s_openvino_model # OpenVINO
|
yolov8n_openvino_model # OpenVINO
|
||||||
yolov8n.engine # TensorRT
|
yolov8n.engine # TensorRT
|
||||||
yolov8n.mlmodel # CoreML (macOS-only)
|
yolov8n.mlmodel # CoreML (macOS-only)
|
||||||
yolov5s_saved_model # TensorFlow SavedModel
|
yolov8n_saved_model # TensorFlow SavedModel
|
||||||
yolov8n.pb # TensorFlow GraphDef
|
yolov8n.pb # TensorFlow GraphDef
|
||||||
yolov8n.tflite # TensorFlow Lite
|
yolov8n.tflite # TensorFlow Lite
|
||||||
yolov5s_edgetpu.tflite # TensorFlow Edge TPU
|
yolov8n_edgetpu.tflite # TensorFlow Edge TPU
|
||||||
yolov5s_paddle_model # PaddlePaddle
|
yolov8n_paddle_model # PaddlePaddle
|
||||||
"""
|
"""
|
||||||
import platform
|
import platform
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -25,14 +25,12 @@ TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' # tqdm bar format
|
|||||||
LOGGING_NAME = 'yolov5'
|
LOGGING_NAME = 'yolov5'
|
||||||
HELP_MSG = \
|
HELP_MSG = \
|
||||||
"""
|
"""
|
||||||
Please refer to below Usage examples for help running YOLOv8
|
Please refer to below Usage examples for help running YOLOv8:
|
||||||
For help visit Ultralytics Community at https://community.ultralytics.com/
|
|
||||||
Submit bug reports to https//github.com/ultralytics/ultralytics
|
|
||||||
|
|
||||||
Install:
|
Install:
|
||||||
pip install ultralytics
|
pip install ultralytics
|
||||||
|
|
||||||
Python usage:
|
Python SDK:
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
|
|
||||||
model = YOLO.new('yolov8n.yaml') # create a new model from scratch
|
model = YOLO.new('yolov8n.yaml') # create a new model from scratch
|
||||||
@ -42,12 +40,15 @@ HELP_MSG = \
|
|||||||
results = model.predict(source='bus.jpg')
|
results = model.predict(source='bus.jpg')
|
||||||
success = model.export(format='onnx')
|
success = model.export(format='onnx')
|
||||||
|
|
||||||
CLI usage:
|
CLI:
|
||||||
yolo task=detect mode=train model=yolov8n.yaml ...
|
yolo task=detect mode=train model=yolov8n.yaml args...
|
||||||
classify predict yolov8n-cls.yaml
|
classify predict yolov8n-cls.yaml args...
|
||||||
segment val yolov8n-seg.yaml
|
segment val yolov8n-seg.yaml args...
|
||||||
|
export yolov8n.pt format=onnx args...
|
||||||
|
|
||||||
For all arguments see https://github.com/ultralytics/ultralytics/blob/main/ultralytics/yolo/utils/configs/default.yaml
|
Docs: https://docs.ultralytics.com
|
||||||
|
Community: https://community.ultralytics.com
|
||||||
|
GitHub: https://github.com/ultralytics/ultralytics
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Settings
|
# Settings
|
||||||
@ -56,7 +57,6 @@ HELP_MSG = \
|
|||||||
pd.options.display.max_columns = 10
|
pd.options.display.max_columns = 10
|
||||||
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
|
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
|
||||||
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
|
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
|
||||||
os.environ['OMP_NUM_THREADS'] = '1' if platform.system() == 'darwin' else str(NUM_THREADS) # OpenMP (PyTorch and SciPy)
|
|
||||||
|
|
||||||
|
|
||||||
def is_colab():
|
def is_colab():
|
||||||
|
@ -36,8 +36,8 @@ def on_val_end(trainer):
|
|||||||
if trainer.epoch == 0:
|
if trainer.epoch == 0:
|
||||||
model_info = {
|
model_info = {
|
||||||
"Parameters": get_num_params(trainer.model),
|
"Parameters": get_num_params(trainer.model),
|
||||||
"GFLOPs": round(get_flops(trainer.model), 1),
|
"GFLOPs": round(get_flops(trainer.model), 3),
|
||||||
"Inference speed (ms/img)": round(trainer.validator.speed[1], 1)}
|
"Inference speed (ms/img)": round(trainer.validator.speed[1], 3)}
|
||||||
Task.current_task().connect(model_info, name='Model')
|
Task.current_task().connect(model_info, name='Model')
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,8 +19,8 @@ def on_val_end(trainer):
|
|||||||
if trainer.epoch == 0:
|
if trainer.epoch == 0:
|
||||||
model_info = {
|
model_info = {
|
||||||
"model/parameters": get_num_params(trainer.model),
|
"model/parameters": get_num_params(trainer.model),
|
||||||
"model/GFLOPs": round(get_flops(trainer.model), 1),
|
"model/GFLOPs": round(get_flops(trainer.model), 3),
|
||||||
"model/speed(ms)": round(trainer.validator.speed[1], 1)}
|
"model/speed(ms)": round(trainer.validator.speed[1], 3)}
|
||||||
wandb.run.log(model_info, step=trainer.epoch + 1)
|
wandb.run.log(model_info, step=trainer.epoch + 1)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user