mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-24 06:14:55 +08:00
Change class depending on dataset in model interface (#77)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
24a7c068ad
commit
48c95ba083
@ -49,6 +49,16 @@ def test_model_resume():
|
|||||||
print("Successfully caught resume assert!")
|
print("Successfully caught resume assert!")
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_train_pretrained():
|
||||||
|
model = YOLO()
|
||||||
|
model.load("balloon-detect.pt")
|
||||||
|
model.train(data="coco128.yaml", epochs=1, img_size=32)
|
||||||
|
model.new("yolov5n.yaml")
|
||||||
|
model.train(data="coco128.yaml", epochs=1, img_size=32)
|
||||||
|
img = torch.rand(512 * 512 * 3).view(1, 3, 512, 512)
|
||||||
|
model(img)
|
||||||
|
|
||||||
|
|
||||||
def test():
|
def test():
|
||||||
test_model_forward()
|
test_model_forward()
|
||||||
test_model_info()
|
test_model_info()
|
||||||
@ -56,6 +66,7 @@ def test():
|
|||||||
test_visualize_preds()
|
test_visualize_preds()
|
||||||
test_val()
|
test_val()
|
||||||
test_model_resume()
|
test_model_resume()
|
||||||
|
test_model_train_pretrained()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
from omegaconf import OmegaConf
|
|
||||||
|
|
||||||
from ultralytics import yolo
|
from ultralytics import yolo
|
||||||
|
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
||||||
from ultralytics.yolo.utils import LOGGER
|
from ultralytics.yolo.utils import LOGGER
|
||||||
from ultralytics.yolo.utils.checks import check_yaml
|
from ultralytics.yolo.utils.checks import check_yaml
|
||||||
@ -146,7 +146,7 @@ class YOLO:
|
|||||||
**kwargs (Any): Any number of arguments representing the training configuration. List of all args can be found in 'config' section.
|
**kwargs (Any): Any number of arguments representing the training configuration. List of all args can be found in 'config' section.
|
||||||
You can pass all arguments as a yaml file in `cfg`. Other args are ignored if `cfg` file is passed
|
You can pass all arguments as a yaml file in `cfg`. Other args are ignored if `cfg` file is passed
|
||||||
"""
|
"""
|
||||||
if not self.model and not self.ckpt:
|
if not self.model:
|
||||||
raise Exception("model not initialized. Use .new() or .load()")
|
raise Exception("model not initialized. Use .new() or .load()")
|
||||||
|
|
||||||
overrides = kwargs
|
overrides = kwargs
|
||||||
@ -159,8 +159,10 @@ class YOLO:
|
|||||||
raise Exception("dataset not provided! Please check if you have defined `data` in you configs")
|
raise Exception("dataset not provided! Please check if you have defined `data` in you configs")
|
||||||
|
|
||||||
self.trainer = self.TrainerClass(overrides=overrides)
|
self.trainer = self.TrainerClass(overrides=overrides)
|
||||||
# load pre-trained weights if found, else use the loaded model
|
self.trainer.model = self.trainer.load_model(weights=self.ckpt,
|
||||||
self.trainer.model = self.trainer.load_model(weights=self.ckpt) if self.ckpt else self.model
|
model_cfg=self.model.yaml if self.task != "classify" else None)
|
||||||
|
self.model = self.trainer.model # override here to save memory
|
||||||
|
|
||||||
self.trainer.train()
|
self.trainer.train()
|
||||||
|
|
||||||
def resume(self, task=None, model=None):
|
def resume(self, task=None, model=None):
|
||||||
@ -199,6 +201,9 @@ class YOLO:
|
|||||||
|
|
||||||
return task
|
return task
|
||||||
|
|
||||||
|
def to(self, device):
|
||||||
|
self.model.to(device)
|
||||||
|
|
||||||
def _guess_ops_from_task(self, task):
|
def _guess_ops_from_task(self, task):
|
||||||
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[task]
|
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[task]
|
||||||
# warning: eval is unsafe. Use with caution
|
# warning: eval is unsafe. Use with caution
|
||||||
|
Loading…
x
Reference in New Issue
Block a user