mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-24 06:14:55 +08:00
Model interface enhancement (#106)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
38d6df55cb
commit
384f0ef1c6
6
.github/workflows/ci.yaml
vendored
6
.github/workflows/ci.yaml
vendored
@ -91,14 +91,14 @@ jobs:
|
|||||||
shell: bash # for Windows compatibility
|
shell: bash # for Windows compatibility
|
||||||
run: |
|
run: |
|
||||||
yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=1 imgsz=64
|
yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=1 imgsz=64
|
||||||
yolo task=detect mode=val model=runs/train/exp/weights/last.pt imgsz=64
|
yolo task=detect mode=val model=runs/detect/train/weights/last.pt imgsz=64
|
||||||
- name: Test segmentation # TODO: segmentation CI
|
- name: Test segmentation # TODO: segmentation CI
|
||||||
shell: bash # for Windows compatibility
|
shell: bash # for Windows compatibility
|
||||||
run: |
|
run: |
|
||||||
# yolo task=segment mode=train model=yolov8n-seg.yaml data=coco128-seg.yaml epochs=1 imgsz=64
|
# yolo task=segment mode=train model=yolov8n-seg.yaml data=coco128-seg.yaml epochs=1 imgsz=64
|
||||||
# yolo task=segment mode=val model=runs/train/exp2/weights/last.pt data=coco128-seg.yaml imgsz=64
|
# yolo task=segment mode=val model=runs/segment/train/weights/last.pt data=coco128-seg.yaml imgsz=64
|
||||||
- name: Test classification # TODO: change to exp3 on Segmentation CI update
|
- name: Test classification # TODO: change to exp3 on Segmentation CI update
|
||||||
shell: bash # for Windows compatibility
|
shell: bash # for Windows compatibility
|
||||||
run: |
|
run: |
|
||||||
yolo task=classify mode=train model=resnet18 data=mnist160 epochs=1 imgsz=32
|
yolo task=classify mode=train model=resnet18 data=mnist160 epochs=1 imgsz=32
|
||||||
yolo task=classify mode=val model=runs/train/exp2/weights/last.pt data=mnist160
|
yolo task=classify mode=val model=runs/classify/train/weights/last.pt data=mnist160
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ultralytics.yolo import YOLO
|
from ultralytics import YOLO
|
||||||
|
|
||||||
|
|
||||||
def test_model_forward():
|
def test_model_forward():
|
||||||
model = YOLO()
|
model = YOLO()
|
||||||
model.new("yolov8n-seg.yaml")
|
model.new("yolov8n.yaml")
|
||||||
img = torch.rand(512 * 512 * 3).view(1, 3, 512, 512)
|
img = torch.rand(512 * 512 * 3).view(1, 3, 512, 512)
|
||||||
model.forward(img)
|
model.forward(img)
|
||||||
model(img)
|
model(img)
|
||||||
@ -15,7 +15,7 @@ def test_model_info():
|
|||||||
model = YOLO()
|
model = YOLO()
|
||||||
model.new("yolov8n.yaml")
|
model.new("yolov8n.yaml")
|
||||||
model.info()
|
model.info()
|
||||||
model.load("balloon-detect.pt")
|
model.load("best.pt")
|
||||||
model.info(verbose=True)
|
model.info(verbose=True)
|
||||||
|
|
||||||
|
|
||||||
@ -23,35 +23,35 @@ def test_model_fuse():
|
|||||||
model = YOLO()
|
model = YOLO()
|
||||||
model.new("yolov8n.yaml")
|
model.new("yolov8n.yaml")
|
||||||
model.fuse()
|
model.fuse()
|
||||||
model.load("balloon-detect.pt")
|
model.load("best.pt")
|
||||||
model.fuse()
|
model.fuse()
|
||||||
|
|
||||||
|
|
||||||
def test_visualize_preds():
|
def test_visualize_preds():
|
||||||
model = YOLO()
|
model = YOLO()
|
||||||
model.load("balloon-segment.pt")
|
model.load("best.pt")
|
||||||
model.predict(source="ultralytics/assets")
|
model.predict(source="ultralytics/assets")
|
||||||
|
|
||||||
|
|
||||||
def test_val():
|
def test_val():
|
||||||
model = YOLO()
|
model = YOLO()
|
||||||
model.load("balloon-segment.pt")
|
model.load("best.pt")
|
||||||
model.val(data="coco128-seg.yaml", imgsz=32)
|
model.val(data="coco128.yaml", imgsz=32)
|
||||||
|
|
||||||
|
|
||||||
def test_model_resume():
|
def test_model_resume():
|
||||||
model = YOLO()
|
model = YOLO()
|
||||||
model.new("yolov8n-seg.yaml")
|
model.new("yolov8n.yaml")
|
||||||
model.train(epochs=1, imgsz=32, data="coco128-seg.yaml")
|
model.train(epochs=1, imgsz=32, data="coco128.yaml")
|
||||||
try:
|
try:
|
||||||
model.resume(task="segment")
|
model.resume(task="detect")
|
||||||
except AssertionError:
|
except AssertionError:
|
||||||
print("Successfully caught resume assert!")
|
print("Successfully caught resume assert!")
|
||||||
|
|
||||||
|
|
||||||
def test_model_train_pretrained():
|
def test_model_train_pretrained():
|
||||||
model = YOLO()
|
model = YOLO()
|
||||||
model.load("balloon-detect.pt")
|
model.load("best.pt")
|
||||||
model.train(data="coco128.yaml", epochs=1, imgsz=32)
|
model.train(data="coco128.yaml", epochs=1, imgsz=32)
|
||||||
model.new("yolov8n.yaml")
|
model.new("yolov8n.yaml")
|
||||||
model.train(data="coco128.yaml", epochs=1, imgsz=32)
|
model.train(data="coco128.yaml", epochs=1, imgsz=32)
|
||||||
|
@ -43,6 +43,7 @@ class YOLO:
|
|||||||
self.trainer = None
|
self.trainer = None
|
||||||
self.task = None
|
self.task = None
|
||||||
self.ckpt = None
|
self.ckpt = None
|
||||||
|
self.overrides = {}
|
||||||
|
|
||||||
def new(self, cfg: str):
|
def new(self, cfg: str):
|
||||||
"""
|
"""
|
||||||
@ -69,6 +70,10 @@ class YOLO:
|
|||||||
"""
|
"""
|
||||||
self.ckpt = torch.load(weights, map_location="cpu")
|
self.ckpt = torch.load(weights, map_location="cpu")
|
||||||
self.task = self.ckpt["train_args"]["task"]
|
self.task = self.ckpt["train_args"]["task"]
|
||||||
|
self.overrides = dict(self.ckpt["train_args"])
|
||||||
|
self.overrides["device"] = '' # reset device
|
||||||
|
LOGGER.info("Device has been reset to ''")
|
||||||
|
|
||||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._guess_ops_from_task(
|
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._guess_ops_from_task(
|
||||||
task=self.task)
|
task=self.task)
|
||||||
self.model = attempt_load_weights(weights)
|
self.model = attempt_load_weights(weights)
|
||||||
@ -107,6 +112,7 @@ class YOLO:
|
|||||||
source (str): Accepts all source types accepted by yolo
|
source (str): Accepts all source types accepted by yolo
|
||||||
**kwargs : Any other args accepted by the predictors. Too see all args check 'configuration' section in the docs
|
**kwargs : Any other args accepted by the predictors. Too see all args check 'configuration' section in the docs
|
||||||
"""
|
"""
|
||||||
|
kwargs.update(self.overrides)
|
||||||
predictor = self.PredictorClass(overrides=kwargs)
|
predictor = self.PredictorClass(overrides=kwargs)
|
||||||
|
|
||||||
# check size type
|
# check size type
|
||||||
@ -119,7 +125,7 @@ class YOLO:
|
|||||||
predictor.setup(model=self.model, source=source)
|
predictor.setup(model=self.model, source=source)
|
||||||
predictor()
|
predictor()
|
||||||
|
|
||||||
def val(self, data, **kwargs):
|
def val(self, data=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
Validate a model on a given dataset
|
Validate a model on a given dataset
|
||||||
|
|
||||||
@ -130,8 +136,9 @@ class YOLO:
|
|||||||
if not self.model:
|
if not self.model:
|
||||||
raise Exception("model not initialized!")
|
raise Exception("model not initialized!")
|
||||||
|
|
||||||
|
kwargs.update(self.overrides)
|
||||||
args = get_config(config=DEFAULT_CONFIG, overrides=kwargs)
|
args = get_config(config=DEFAULT_CONFIG, overrides=kwargs)
|
||||||
args.data = data
|
args.data = data or args.data
|
||||||
args.task = self.task
|
args.task = self.task
|
||||||
|
|
||||||
validator = self.ValidatorClass(args=args)
|
validator = self.ValidatorClass(args=args)
|
||||||
|
@ -86,10 +86,15 @@ class BasePredictor:
|
|||||||
|
|
||||||
# data
|
# data
|
||||||
if self.data:
|
if self.data:
|
||||||
|
try:
|
||||||
if self.data.endswith(".yaml"):
|
if self.data.endswith(".yaml"):
|
||||||
self.data = check_dataset_yaml(self.data)
|
self.data = check_dataset_yaml(self.data)
|
||||||
else:
|
else:
|
||||||
self.data = check_dataset(self.data)
|
self.data = check_dataset(self.data)
|
||||||
|
except AssertionError as e:
|
||||||
|
LOGGER.info(f"Error ocurred: {e}")
|
||||||
|
finally:
|
||||||
|
LOGGER.info("Predictor will continue without reading the dataset")
|
||||||
|
|
||||||
# model
|
# model
|
||||||
device = select_device(self.args.device)
|
device = select_device(self.args.device)
|
||||||
|
@ -46,10 +46,15 @@ class BaseTrainer:
|
|||||||
self.validator = None
|
self.validator = None
|
||||||
self.model = None
|
self.model = None
|
||||||
self.callbacks = defaultdict(list)
|
self.callbacks = defaultdict(list)
|
||||||
self.save_dir = increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
|
|
||||||
|
# dirs
|
||||||
|
project = overrides.get("project") or self.args.task
|
||||||
|
name = overrides.get("name") or self.args.mode
|
||||||
|
self.save_dir = increment_path(Path("runs") / project / name, exist_ok=self.args.exist_ok)
|
||||||
self.wdir = self.save_dir / 'weights' # weights dir
|
self.wdir = self.save_dir / 'weights' # weights dir
|
||||||
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
|
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
|
||||||
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
|
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
|
||||||
|
|
||||||
self.batch_size = self.args.batch_size
|
self.batch_size = self.args.batch_size
|
||||||
self.epochs = self.args.epochs
|
self.epochs = self.args.epochs
|
||||||
self.start_epoch = 0
|
self.start_epoch = 0
|
||||||
|
@ -6,7 +6,7 @@ from omegaconf import DictConfig, OmegaConf
|
|||||||
from ultralytics.yolo.utils.configs.hydra_patch import check_config_mismatch
|
from ultralytics.yolo.utils.configs.hydra_patch import check_config_mismatch
|
||||||
|
|
||||||
|
|
||||||
def get_config(config: Union[str, DictConfig], overrides: Union[str, Dict]):
|
def get_config(config: Union[str, DictConfig], overrides: Union[str, Dict] = {}):
|
||||||
"""
|
"""
|
||||||
Accepts yaml file name or DictConfig containing experiment configuration.
|
Accepts yaml file name or DictConfig containing experiment configuration.
|
||||||
Returns training args namespace
|
Returns training args namespace
|
||||||
|
Loading…
x
Reference in New Issue
Block a user