mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
Change class depending on dataset in model interface (#77)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
24a7c068ad
commit
48c95ba083
@ -49,6 +49,16 @@ def test_model_resume():
|
||||
print("Successfully caught resume assert!")
|
||||
|
||||
|
||||
def test_model_train_pretrained():
|
||||
model = YOLO()
|
||||
model.load("balloon-detect.pt")
|
||||
model.train(data="coco128.yaml", epochs=1, img_size=32)
|
||||
model.new("yolov5n.yaml")
|
||||
model.train(data="coco128.yaml", epochs=1, img_size=32)
|
||||
img = torch.rand(512 * 512 * 3).view(1, 3, 512, 512)
|
||||
model(img)
|
||||
|
||||
|
||||
def test():
|
||||
test_model_forward()
|
||||
test_model_info()
|
||||
@ -56,6 +66,7 @@ def test():
|
||||
test_visualize_preds()
|
||||
test_val()
|
||||
test_model_resume()
|
||||
test_model_train_pretrained()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,8 +1,8 @@
|
||||
import torch
|
||||
import yaml
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from ultralytics import yolo
|
||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
||||
from ultralytics.yolo.utils import LOGGER
|
||||
from ultralytics.yolo.utils.checks import check_yaml
|
||||
@ -146,7 +146,7 @@ class YOLO:
|
||||
**kwargs (Any): Any number of arguments representing the training configuration. List of all args can be found in 'config' section.
|
||||
You can pass all arguments as a yaml file in `cfg`. Other args are ignored if `cfg` file is passed
|
||||
"""
|
||||
if not self.model and not self.ckpt:
|
||||
if not self.model:
|
||||
raise Exception("model not initialized. Use .new() or .load()")
|
||||
|
||||
overrides = kwargs
|
||||
@ -159,8 +159,10 @@ class YOLO:
|
||||
raise Exception("dataset not provided! Please check if you have defined `data` in you configs")
|
||||
|
||||
self.trainer = self.TrainerClass(overrides=overrides)
|
||||
# load pre-trained weights if found, else use the loaded model
|
||||
self.trainer.model = self.trainer.load_model(weights=self.ckpt) if self.ckpt else self.model
|
||||
self.trainer.model = self.trainer.load_model(weights=self.ckpt,
|
||||
model_cfg=self.model.yaml if self.task != "classify" else None)
|
||||
self.model = self.trainer.model # override here to save memory
|
||||
|
||||
self.trainer.train()
|
||||
|
||||
def resume(self, task=None, model=None):
|
||||
@ -199,6 +201,9 @@ class YOLO:
|
||||
|
||||
return task
|
||||
|
||||
def to(self, device):
|
||||
self.model.to(device)
|
||||
|
||||
def _guess_ops_from_task(self, task):
|
||||
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[task]
|
||||
# warning: eval is unsafe. Use with caution
|
||||
|
Loading…
x
Reference in New Issue
Block a user