Mirror of https://github.com/THU-MIG/yolov10.git
Model saved_model export (#151)

parent d17d1e064d
commit f8a13c49a0
@@ -313,14 +313,11 @@ class Exporter:
         # Simplify
         if self.args.simplify:
             try:
-                cuda = torch.cuda.is_available()
-                check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1'))
-                import onnxsim  # noqa
+                check_requirements('onnxsim')
+                import onnxsim
 
                 LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
-                model_onnx, check = onnxsim.simplify(model_onnx)
-                assert check, 'assert check failed'
-                onnx.save(model_onnx, f)
+                subprocess.run(f'onnxsim {f} {f}', shell=True)
             except Exception as e:
                 LOGGER.info(f'{prefix} simplifier failure: {e}')
         return f, model_onnx
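For context on the hunk above: the exporter drops the in-process onnx-simplifier call (onnxsim.simplify plus onnx.save) in favour of shelling out to the onnxsim CLI. A minimal sketch of the two approaches, assuming a local file named model.onnx (the filename is illustrative, not from this commit):

    import subprocess

    import onnx
    import onnxsim

    # Old behaviour: simplify in-process with the Python API.
    model_onnx = onnx.load('model.onnx')
    model_simplified, check = onnxsim.simplify(model_onnx)
    assert check, 'simplified ONNX model could not be validated'
    onnx.save(model_simplified, 'model.onnx')

    # New behaviour: let the onnxsim CLI rewrite the file in place.
    subprocess.run('onnxsim model.onnx model.onnx', shell=True, check=True)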
@@ -460,6 +457,40 @@ class Exporter:
                                 iou_thres=0.45,
                                 conf_thres=0.25,
                                 prefix=colorstr('TensorFlow SavedModel:')):
+
+        # YOLOv5 TensorFlow SavedModel export
+        try:
+            import tensorflow as tf  # noqa
+        except ImportError:
+            check_requirements(f"tensorflow{'' if torch.cuda.is_available() else '-macos' if MACOS else '-cpu'}")
+            import tensorflow as tf  # noqa
+        check_requirements(("onnx", "onnx2tf", "sng4onnx", "onnxsim", "onnx_graphsurgeon"),
+                           cmds="--extra-index-url https://pypi.ngc.nvidia.com ")
+
+        LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
+        f = str(self.file).replace(self.file.suffix, '_saved_model')
+
+        # Export to ONNX
+        self._export_onnx()
+        onnx = self.file.with_suffix('.onnx')
+
+        # Export to TF SavedModel
+        subprocess.run(f'onnx2tf -i {onnx} --output_signaturedefs -o {f}', shell=True)
+
+        # Load saved_model
+        keras_model = tf.saved_model.load(f, tags=None, options=None)
+
+        return f, keras_model
+
+    @try_export
+    def _export_saved_model_OLD(self,
+                                nms=False,
+                                agnostic_nms=False,
+                                topk_per_class=100,
+                                topk_all=100,
+                                iou_thres=0.45,
+                                conf_thres=0.25,
+                                prefix=colorstr('TensorFlow SavedModel:')):
         # YOLOv5 TensorFlow SavedModel export
         try:
             import tensorflow as tf  # noqa
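The new _export_saved_model path chains ONNX export, the onnx2tf CLI, and tf.saved_model.load. A rough usage sketch of the artefact it produces; the output directory name, the input keyword 'images', and the 1x640x640x3 shape are assumptions for illustration, not values guaranteed by this commit:

    import numpy as np
    import tensorflow as tf

    # Directory written by: onnx2tf -i model.onnx --output_signaturedefs -o model_saved_model
    saved_model = tf.saved_model.load('model_saved_model')

    # --output_signaturedefs attaches a serving signature; inspect it before calling.
    infer = saved_model.signatures['serving_default']
    print(infer.structured_input_signature)

    # Dummy NHWC batch; adjust the input keyword to whatever the signature reports.
    dummy = tf.constant(np.zeros((1, 640, 640, 3), dtype=np.float32))
    outputs = infer(images=dummy)
    print({name: tensor.shape for name, tensor in outputs.items()})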
@@ -52,8 +52,8 @@ class YOLO:
         # Load or create new YOLO model
         {'.pt': self._load, '.yaml': self._new}[Path(model).suffix](model)
 
-    def __call__(self, source):
-        return self.predict(source)
+    def __call__(self, source, **kwargs):
+        return self.predict(source, **kwargs)
 
     def _new(self, cfg: str, verbose=True):
         """
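With __call__ now forwarding **kwargs to predict, per-call overrides can be passed straight through the callable interface. A small usage sketch in the style of the upstream ultralytics API; the weights file and the conf override are assumptions for this snapshot, not values taken from the commit:

    from ultralytics import YOLO

    model = YOLO('yolov8n.pt')            # illustrative weights file

    results = model('bus.jpg')            # before: only the source could be passed
    results = model('bus.jpg', conf=0.5)  # after: extra kwargs reach predict() as overrides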
@@ -218,3 +218,4 @@ class YOLO:
         args.pop("name", None)
         args.pop("batch", None)
         args.pop("epochs", None)
+        args.pop("cache", None)
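The cleanup here strips run-specific keys from the arguments recovered from a checkpoint before they are reused, and this change adds cache to that list. A minimal sketch of the pattern with made-up values:

    # Arguments as they might be restored from a checkpoint (values are made up).
    args = {'imgsz': 640, 'name': 'train42', 'batch': 16, 'epochs': 100, 'cache': 'ram'}

    # Drop run-specific keys; dict.pop(key, None) is a no-op when the key is absent.
    for key in ('name', 'batch', 'epochs', 'cache'):
        args.pop(key, None)

    print(args)  # {'imgsz': 640}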