mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Merge model()
and model.predict()
(#146)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
99275814f1
commit
46cb657b64
@ -1,19 +1,17 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
from ultralytics.yolo.utils import ROOT, SETTINGS
|
from ultralytics.yolo.utils import ROOT, SETTINGS
|
||||||
|
|
||||||
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt'
|
|
||||||
CFG = 'yolov8n.yaml'
|
CFG = 'yolov8n.yaml'
|
||||||
|
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt'
|
||||||
|
SOURCE = ROOT / 'assets/bus.jpg'
|
||||||
|
|
||||||
|
|
||||||
def test_model_forward():
|
def test_model_forward():
|
||||||
model = YOLO(CFG)
|
model = YOLO(CFG)
|
||||||
img = torch.rand(1, 3, 320, 320)
|
model.predict(SOURCE)
|
||||||
model.forward(img)
|
model(SOURCE)
|
||||||
model(img)
|
|
||||||
|
|
||||||
|
|
||||||
def test_model_info():
|
def test_model_info():
|
||||||
@ -43,15 +41,13 @@ def test_val():
|
|||||||
def test_train_scratch():
|
def test_train_scratch():
|
||||||
model = YOLO(CFG)
|
model = YOLO(CFG)
|
||||||
model.train(data="coco128.yaml", epochs=1, imgsz=32)
|
model.train(data="coco128.yaml", epochs=1, imgsz=32)
|
||||||
img = torch.rand(1, 3, 320, 320)
|
model(SOURCE)
|
||||||
model(img)
|
|
||||||
|
|
||||||
|
|
||||||
def test_train_pretrained():
|
def test_train_pretrained():
|
||||||
model = YOLO(MODEL)
|
model = YOLO(MODEL)
|
||||||
model.train(data="coco128.yaml", epochs=1, imgsz=32)
|
model.train(data="coco128.yaml", epochs=1, imgsz=32)
|
||||||
img = torch.rand(1, 3, 320, 320)
|
model(SOURCE)
|
||||||
model(img)
|
|
||||||
|
|
||||||
|
|
||||||
def test_export_torchscript():
|
def test_export_torchscript():
|
||||||
@ -100,11 +96,3 @@ def test_export_paddle():
|
|||||||
def test_all_model_yamls():
|
def test_all_model_yamls():
|
||||||
for m in list((ROOT / 'yolo/v8/models').rglob('*.yaml')):
|
for m in list((ROOT / 'yolo/v8/models').rglob('*.yaml')):
|
||||||
YOLO(m.name)
|
YOLO(m.name)
|
||||||
|
|
||||||
|
|
||||||
# def run_all_tests(): # do not name function test_...
|
|
||||||
# pass
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# if __name__ == "__main__":
|
|
||||||
# run_all_tests()
|
|
||||||
|
@ -1,12 +1,10 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from ultralytics import yolo # noqa
|
from ultralytics import yolo # noqa
|
||||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_weights
|
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_weights
|
||||||
from ultralytics.yolo.configs import get_config
|
from ultralytics.yolo.configs import get_config
|
||||||
from ultralytics.yolo.engine.exporter import Exporter
|
from ultralytics.yolo.engine.exporter import Exporter
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, HELP_MSG, LOGGER, yaml_load
|
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, yaml_load
|
||||||
from ultralytics.yolo.utils.checks import check_imgsz, check_yaml
|
from ultralytics.yolo.utils.checks import check_imgsz, check_yaml
|
||||||
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode
|
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode
|
||||||
|
|
||||||
@ -55,6 +53,9 @@ class YOLO:
|
|||||||
# Load or create new YOLO model
|
# Load or create new YOLO model
|
||||||
{'.pt': self._load, '.yaml': self._new}[Path(model).suffix](model)
|
{'.pt': self._load, '.yaml': self._new}[Path(model).suffix](model)
|
||||||
|
|
||||||
|
def __call__(self, source):
|
||||||
|
return self.predict(source)
|
||||||
|
|
||||||
def _new(self, cfg: str, verbose=True):
|
def _new(self, cfg: str, verbose=True):
|
||||||
"""
|
"""
|
||||||
Initializes a new model and infers the task type from the model definitions.
|
Initializes a new model and infers the task type from the model definitions.
|
||||||
@ -211,14 +212,6 @@ class YOLO:
|
|||||||
|
|
||||||
return model_class, trainer_class, validator_class, predictor_class
|
return model_class, trainer_class, validator_class, predictor_class
|
||||||
|
|
||||||
@smart_inference_mode()
|
|
||||||
def __call__(self, imgs):
|
|
||||||
device = next(self.model.parameters()).device # get model device
|
|
||||||
return self.model(imgs.to(device))
|
|
||||||
|
|
||||||
def forward(self, imgs):
|
|
||||||
return self.__call__(imgs)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reset_ckpt_args(args):
|
def _reset_ckpt_args(args):
|
||||||
args.pop("device", None)
|
args.pop("device", None)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user