mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-24 06:14:55 +08:00
Pass export=True to RTDETRDecoder (#3550)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
adb1a6b091
commit
91905b4b0b
@ -59,7 +59,7 @@ from pathlib import Path
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ultralytics.nn.autobackend import check_class_names
|
from ultralytics.nn.autobackend import check_class_names
|
||||||
from ultralytics.nn.modules import C2f, Detect, Segment
|
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
|
||||||
from ultralytics.nn.tasks import DetectionModel, SegmentationModel
|
from ultralytics.nn.tasks import DetectionModel, SegmentationModel
|
||||||
from ultralytics.yolo.cfg import get_cfg
|
from ultralytics.yolo.cfg import get_cfg
|
||||||
from ultralytics.yolo.utils import (DEFAULT_CFG, LINUX, LOGGER, MACOS, __version__, callbacks, colorstr,
|
from ultralytics.yolo.utils import (DEFAULT_CFG, LINUX, LOGGER, MACOS, __version__, callbacks, colorstr,
|
||||||
@ -157,13 +157,13 @@ class Exporter:
|
|||||||
|
|
||||||
# Load PyTorch model
|
# Load PyTorch model
|
||||||
self.device = select_device('cpu' if self.args.device is None else self.args.device)
|
self.device = select_device('cpu' if self.args.device is None else self.args.device)
|
||||||
|
|
||||||
|
# Checks
|
||||||
|
model.names = check_class_names(model.names)
|
||||||
if self.args.half and onnx and self.device.type == 'cpu':
|
if self.args.half and onnx and self.device.type == 'cpu':
|
||||||
LOGGER.warning('WARNING ⚠️ half=True only compatible with GPU export, i.e. use device=0')
|
LOGGER.warning('WARNING ⚠️ half=True only compatible with GPU export, i.e. use device=0')
|
||||||
self.args.half = False
|
self.args.half = False
|
||||||
assert not self.args.dynamic, 'half=True not compatible with dynamic=True, i.e. use only one.'
|
assert not self.args.dynamic, 'half=True not compatible with dynamic=True, i.e. use only one.'
|
||||||
|
|
||||||
# Checks
|
|
||||||
model.names = check_class_names(model.names)
|
|
||||||
self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size
|
self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size
|
||||||
if self.args.optimize:
|
if self.args.optimize:
|
||||||
assert self.device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'
|
assert self.device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'
|
||||||
@ -185,7 +185,7 @@ class Exporter:
|
|||||||
model.float()
|
model.float()
|
||||||
model = model.fuse()
|
model = model.fuse()
|
||||||
for k, m in model.named_modules():
|
for k, m in model.named_modules():
|
||||||
if isinstance(m, (Detect, Segment)):
|
if isinstance(m, (Detect, RTDETRDecoder)): # Segment and Pose use Detect base class
|
||||||
m.dynamic = self.args.dynamic
|
m.dynamic = self.args.dynamic
|
||||||
m.export = True
|
m.export = True
|
||||||
m.format = self.args.format
|
m.format = self.args.format
|
||||||
|
Loading…
x
Reference in New Issue
Block a user