mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
Model saved_model export (#151)
This commit is contained in:
parent
d17d1e064d
commit
f8a13c49a0
@ -313,14 +313,11 @@ class Exporter:
|
||||
# Simplify
|
||||
if self.args.simplify:
|
||||
try:
|
||||
cuda = torch.cuda.is_available()
|
||||
check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1'))
|
||||
import onnxsim # noqa
|
||||
check_requirements('onnxsim')
|
||||
import onnxsim
|
||||
|
||||
LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
|
||||
model_onnx, check = onnxsim.simplify(model_onnx)
|
||||
assert check, 'assert check failed'
|
||||
onnx.save(model_onnx, f)
|
||||
subprocess.run(f'onnxsim {f} {f}', shell=True)
|
||||
except Exception as e:
|
||||
LOGGER.info(f'{prefix} simplifier failure: {e}')
|
||||
return f, model_onnx
|
||||
@ -460,6 +457,40 @@ class Exporter:
|
||||
iou_thres=0.45,
|
||||
conf_thres=0.25,
|
||||
prefix=colorstr('TensorFlow SavedModel:')):
|
||||
|
||||
# YOLOv5 TensorFlow SavedModel export
|
||||
try:
|
||||
import tensorflow as tf # noqa
|
||||
except ImportError:
|
||||
check_requirements(f"tensorflow{'' if torch.cuda.is_available() else '-macos' if MACOS else '-cpu'}")
|
||||
import tensorflow as tf # noqa
|
||||
check_requirements(("onnx", "onnx2tf", "sng4onnx", "onnxsim", "onnx_graphsurgeon"),
|
||||
cmds="--extra-index-url https://pypi.ngc.nvidia.com ")
|
||||
|
||||
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
|
||||
f = str(self.file).replace(self.file.suffix, '_saved_model')
|
||||
|
||||
# Export to ONNX
|
||||
self._export_onnx()
|
||||
onnx = self.file.with_suffix('.onnx')
|
||||
|
||||
# Export to TF SavedModel
|
||||
subprocess.run(f'onnx2tf -i {onnx} --output_signaturedefs -o {f}', shell=True)
|
||||
|
||||
# Load saved_model
|
||||
keras_model = tf.saved_model.load(f, tags=None, options=None)
|
||||
|
||||
return f, keras_model
|
||||
|
||||
@try_export
|
||||
def _export_saved_model_OLD(self,
|
||||
nms=False,
|
||||
agnostic_nms=False,
|
||||
topk_per_class=100,
|
||||
topk_all=100,
|
||||
iou_thres=0.45,
|
||||
conf_thres=0.25,
|
||||
prefix=colorstr('TensorFlow SavedModel:')):
|
||||
# YOLOv5 TensorFlow SavedModel export
|
||||
try:
|
||||
import tensorflow as tf # noqa
|
||||
|
@ -52,8 +52,8 @@ class YOLO:
|
||||
# Load or create new YOLO model
|
||||
{'.pt': self._load, '.yaml': self._new}[Path(model).suffix](model)
|
||||
|
||||
def __call__(self, source):
|
||||
return self.predict(source)
|
||||
def __call__(self, source, **kwargs):
|
||||
return self.predict(source, **kwargs)
|
||||
|
||||
def _new(self, cfg: str, verbose=True):
|
||||
"""
|
||||
@ -218,3 +218,4 @@ class YOLO:
|
||||
args.pop("name", None)
|
||||
args.pop("batch", None)
|
||||
args.pop("epochs", None)
|
||||
args.pop("cache", None)
|
||||
|
Loading…
x
Reference in New Issue
Block a user