Mirror of https://github.com/THU-MIG/yolov10.git, synced 2025-05-24 06:14:55 +08:00
Model enhancement (#75)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent d85b44f259
commit eb5adf4e0b

@@ -1,13 +1,62 @@
+import torch
+
 from ultralytics.yolo import YOLO
 
 
-def test_model():
+def test_model_forward():
     model = YOLO()
-    model.new("assets/dummy_model.yaml")
-    model.model = "squeezenet1_0"  # temp solution before get_model is implemented
-    # model.load("yolov5n.pt")
-    model.train(data="imagenette160", epochs=1, lr0=0.01)
+    model.new("yolov5n-seg.yaml")
+    img = torch.rand(512 * 512 * 3).view(1, 3, 512, 512)
+    model.forward(img)
+    model(img)
+
+
+def test_model_info():
+    model = YOLO()
+    model.new("yolov5n.yaml")
+    model.info()
+    model.load("balloon-detect.pt")
+    model.info(verbose=True)
+
+
+def test_model_fuse():
+    model = YOLO()
+    model.new("yolov5n.yaml")
+    model.fuse()
+    model.load("balloon-detect.pt")
+    model.fuse()
+
+
+def test_visualize_preds():
+    model = YOLO()
+    model.load("balloon-segment.pt")
+    model.predict(source="ultralytics/assets")
+
+
+def test_val():
+    model = YOLO()
+    model.load("balloon-segment.pt")
+    model.val(data="coco128-seg.yaml", img_size=32)
+
+
+def test_model_resume():
+    model = YOLO()
+    model.new("yolov5n-seg.yaml")
+    model.train(epochs=1, img_size=32, data="coco128-seg.yaml")
+    try:
+        model.resume(task="segment")
+    except AssertionError:
+        print("Successfully caught resume assert!")
+
+
+def test():
+    test_model_forward()
+    test_model_info()
+    test_model_fuse()
+    test_visualize_preds()
+    test_val()
+    test_model_resume()
 
 
 if __name__ == "__main__":
-    test_model()
+    test()
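The remaining hunks extend the `YOLO` wrapper class and, further down, `AutoBackend`. Before that, a quick note on the dummy input in `test_model_forward`: the flat-tensor reshape is equivalent to sampling a batched NCHW tensor directly, as this minimal sketch with the test's shapes shows:

```python
import torch

# The test builds a flat random tensor and views it in NCHW layout:
img = torch.rand(512 * 512 * 3).view(1, 3, 512, 512)

# An equivalent, more direct construction with the same shape:
img_direct = torch.rand(1, 3, 512, 512)
assert img.shape == img_direct.shape == (1, 3, 512, 512)
```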
@@ -1,18 +1,28 @@
 import torch
 import yaml
+from omegaconf import OmegaConf
 
 from ultralytics import yolo
+from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
 from ultralytics.yolo.utils import LOGGER
 from ultralytics.yolo.utils.checks import check_yaml
+from ultralytics.yolo.utils.configs import get_config
 from ultralytics.yolo.utils.files import yaml_load
 from ultralytics.yolo.utils.modeling import attempt_load_weights
 from ultralytics.yolo.utils.modeling.tasks import ClassificationModel, DetectionModel, SegmentationModel
+from ultralytics.yolo.utils.torch_utils import smart_inference_mode
 
-# map head: [model, trainer]
+# map head: [model, trainer, validator, predictor]
 MODEL_MAP = {
-    "classify": [ClassificationModel, 'yolo.TYPE.classify.ClassificationTrainer'],
-    "detect": [DetectionModel, 'yolo.TYPE.detect.DetectionTrainer'],
-    "segment": [SegmentationModel, 'yolo.TYPE.segment.SegmentationTrainer']}
+    "classify": [
+        ClassificationModel, 'yolo.TYPE.classify.ClassificationTrainer', 'yolo.TYPE.classify.ClassificationValidator',
+        'yolo.TYPE.classify.ClassificationPredictor'],
+    "detect": [
+        DetectionModel, 'yolo.TYPE.detect.DetectionTrainer', 'yolo.TYPE.detect.DetectionValidator',
+        'yolo.TYPE.detect.DetectionPredictor'],
+    "segment": [
+        SegmentationModel, 'yolo.TYPE.segment.SegmentationTrainer', 'yolo.TYPE.segment.SegmentationValidator',
+        'yolo.TYPE.segment.SegmentationPredictor']}
 
 
 class YOLO:
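The `MODEL_MAP` literals keep `TYPE` as a placeholder for the model generation; it is substituted before the dotted path is resolved with `eval` (see `_guess_ops_from_task` below). A small sketch of the substitution, assuming the default type is `v8`:

```python
# One of the MODEL_MAP string literals; "TYPE" stands in for the model generation.
train_lit = 'yolo.TYPE.detect.DetectionTrainer'

# Assuming self.type == "v8", the resolved dotted path becomes:
resolved = train_lit.replace("TYPE", "v8")
print(resolved)  # yolo.v8.detect.DetectionTrainer
```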
@@ -28,6 +38,8 @@ class YOLO:
         self.type = type
         self.ModelClass = None
         self.TrainerClass = None
+        self.ValidatorClass = None
+        self.PredictorClass = None
         self.model = None
         self.trainer = None
         self.task = None
@@ -43,7 +55,9 @@ class YOLO:
         cfg = check_yaml(cfg)  # check YAML
         with open(cfg, encoding='ascii', errors='ignore') as f:
             cfg = yaml.safe_load(f)  # model dict
-        self.ModelClass, self.TrainerClass, self.task = self._guess_model_trainer_and_task(cfg["head"][-1][-2])
+        self.task = self._guess_task_from_head(cfg["head"][-1][-2])
+        self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._guess_ops_from_task(
+            self.task)
         self.model = self.ModelClass(cfg)  # initialize
 
     def load(self, weights: str):
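In `new()`, `cfg["head"][-1][-2]` picks out the module name of the last head layer, which `_guess_task_from_head` then maps to a task. A sketch with a hypothetical, heavily trimmed model YAML (the `[from, repeats, module, args]` layout is an assumption for illustration):

```python
import yaml

# Hypothetical, trimmed head section of a model YAML.
cfg_text = """
head:
  - [[17, 20, 23], 1, Segment, [80, 32, 256]]
"""
cfg = yaml.safe_load(cfg_text)
head = cfg["head"][-1][-2]  # module name of the last head layer -> "Segment"
print(head.lower())         # "segment" maps to the segmentation task
```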
@@ -56,8 +70,8 @@ class YOLO:
         """
         self.ckpt = torch.load(weights, map_location="cpu")
         self.task = self.ckpt["train_args"]["task"]
-        _, trainer_class_literal = MODEL_MAP[self.task]
-        self.TrainerClass = eval(trainer_class_literal.replace("TYPE", f"v{self.type}"))
+        self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._guess_ops_from_task(
+            task=self.task)
         self.model = attempt_load_weights(weights)
 
     def reset(self):
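`load()` now resolves all four operation classes from the task stored in the checkpoint's `train_args`, so the caller never passes a task explicitly. A minimal usage sketch (the checkpoint name is the placeholder weight referenced by the tests, not a published asset):

```python
from ultralytics.yolo import YOLO

model = YOLO()
# Task, trainer, validator and predictor classes are inferred from the checkpoint.
model.load("balloon-detect.pt")
```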
@@ -70,6 +84,60 @@ class YOLO:
         for p in self.model.parameters():
             p.requires_grad = True
 
+    def info(self, verbose=False):
+        """
+        Logs model info
+
+        Args:
+            verbose (bool): Controls verbosity.
+        """
+        if not self.model:
+            LOGGER.info("model not initialized!")
+        self.model.info(verbose=verbose)
+
+    def fuse(self):
+        if not self.model:
+            LOGGER.info("model not initialized!")
+        self.model.fuse()
+
+    def predict(self, source, **kwargs):
+        """
+        Visualize prediction.
+
+        Args:
+            source (str): Accepts all source types accepted by yolo
+            **kwargs: Any other args accepted by the predictors. To see all args check the 'configuration' section in the docs
+        """
+        predictor = self.PredictorClass(overrides=kwargs)
+
+        # check size type
+        sz = predictor.args.img_size
+        if type(sz) != int:  # received listConfig
+            predictor.args.img_size = [sz[0], sz[0]] if len(sz) == 1 else [sz[0], sz[1]]  # expand
+        else:
+            predictor.args.img_size = [sz, sz]
+
+        predictor.setup(model=self.model, source=source)
+        predictor()
+
+    def val(self, data, **kwargs):
+        """
+        Validate a model on a given dataset
+
+        Args:
+            data (str): The dataset to validate on. Accepts all formats accepted by yolo
+            kwargs: Any other args accepted by the validators. To see all args check the 'configuration' section in the docs
+        """
+        if not self.model:
+            raise Exception("model not initialized!")
+
+        args = get_config(config=DEFAULT_CONFIG, overrides=kwargs)
+        args.data = data
+        args.task = self.task
+
+        validator = self.ValidatorClass(args=args)
+        validator(model=self.model)
+
     def train(self, **kwargs):
         """
         Trains the model on given dataset.
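The new `predict()` and `val()` helpers hand the stored model to the task-specific predictor and validator classes resolved above. A usage sketch mirroring the tests (checkpoint and dataset names are the ones the tests reference):

```python
from ultralytics.yolo import YOLO

model = YOLO()
model.load("balloon-segment.pt")
model.predict(source="ultralytics/assets")       # runs the task-specific predictor
model.val(data="coco128-seg.yaml", img_size=32)  # runs the task-specific validator
```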
@@ -95,22 +163,28 @@ class YOLO:
         self.trainer.model = self.trainer.load_model(weights=self.ckpt) if self.ckpt else self.model
         self.trainer.train()
 
-    def resume(self, task, model=None):
+    def resume(self, task=None, model=None):
         """
-        Resume a training task.
-
+        Resume a training task. Requires either `task` or `model`. `model` takes the higher precedence.
         Args:
             task (str): The task type you want to resume. Automatically finds the last run to resume if `model` is not specified.
-            model (str): [Optional] The model checkpoint to resume from. If not found, the last run of the given task type is resumed.
+            model (str): The model checkpoint to resume from. If not found, the last run of the given task type is resumed.
+                         If `model` is specified
         """
-        if task.lower() not in MODEL_MAP:
-            raise Exception(f"unrecognised task - {task}. Supported tasks are {MODEL_MAP.keys()}")
-        _, trainer_class_literal = MODEL_MAP[task.lower()]
-        self.TrainerClass = eval(trainer_class_literal.replace("TYPE", f"v{self.type}"))
+        if task:
+            if task.lower() not in MODEL_MAP:
+                raise Exception(f"unrecognised task - {task}. Supported tasks are {MODEL_MAP.keys()}")
+        else:
+            ckpt = torch.load(model, map_location="cpu")
+            task = ckpt["train_args"]["task"]
+            del ckpt
+        self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._guess_ops_from_task(
+            task=task.lower())
         self.trainer = self.TrainerClass(overrides={"task": task.lower(), "resume": model if model else True})
         self.trainer.train()
 
-    def _guess_model_trainer_and_task(self, head):
+    @staticmethod
+    def _guess_task_from_head(head):
         task = None
         if head.lower() in ["classify", "classifier", "cls", "fc"]:
             task = "classify"
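With the reworked `resume()`, either a task or a checkpoint is enough; when a checkpoint is given, the task is read from its `train_args`. A sketch of both call styles (the checkpoint path is hypothetical):

```python
from ultralytics.yolo import YOLO

model = YOLO()
model.resume(task="segment")  # picks up the last segmentation run
# or, resuming from an explicit checkpoint (hypothetical path):
# model.resume(model="runs/segment/train/weights/last.pt")
```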
@@ -118,13 +192,27 @@ class YOLO:
             task = "detect"
         if head.lower() in ["segment"]:
             task = "segment"
-        model_class, trainer_class = MODEL_MAP[task]
+
+        if not task:
+            raise Exception(
+                "task or model not recognized! Please refer the docs at : ")  # TODO: add gitHub and docs links
+
+        return task
+
+    def _guess_ops_from_task(self, task):
+        model_class, train_lit, val_lit, pred_lit = MODEL_MAP[task]
         # warning: eval is unsafe. Use with caution
-        trainer_class = eval(trainer_class.replace("TYPE", f"{self.type}"))
+        trainer_class = eval(train_lit.replace("TYPE", f"{self.type}"))
+        validator_class = eval(val_lit.replace("TYPE", f"{self.type}"))
+        predictor_class = eval(pred_lit.replace("TYPE", f"{self.type}"))
 
-        return model_class, trainer_class, task
+        return model_class, trainer_class, validator_class, predictor_class
 
+    @smart_inference_mode()
     def __call__(self, imgs):
         if not self.model:
             LOGGER.info("model not initialized!")
         return self.model(imgs)
+
+    def forward(self, imgs):
+        return self.__call__(imgs)
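`__call__` now runs under `smart_inference_mode()` and `forward()` simply delegates to it, so an initialized model can be used like a plain callable. A sketch based on the tests above:

```python
import torch

from ultralytics.yolo import YOLO

model = YOLO()
model.new("yolov5n-seg.yaml")
img = torch.rand(1, 3, 512, 512)
out = model(img)           # __call__, wrapped in smart_inference_mode()
out2 = model.forward(img)  # forward() just delegates to __call__
```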
@@ -37,15 +37,23 @@ class AutoBackend(nn.Module):
 
         super().__init__()
         w = str(weights[0] if isinstance(weights, list) else weights)
+        nn_module = isinstance(weights, torch.nn.Module)
         pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, triton = self._model_type(w)
-        fp16 &= pt or jit or onnx or engine  # FP16
+        fp16 &= pt or jit or onnx or engine or nn_module  # FP16
         nhwc = coreml or saved_model or pb or tflite or edgetpu  # BHWC formats (vs torch BCWH)
         stride = 32  # default stride
         cuda = torch.cuda.is_available() and device.type != 'cpu'  # use CUDA
-        if not (pt or triton):
+        if not (pt or triton or nn_module):
             w = attempt_download(w)  # download if not local
 
-        if pt:  # PyTorch
+        # NOTE: special case: in-memory pytorch model
+        if nn_module:
+            model = weights.to(device)
+            model = model.fuse() if fuse else model
+            names = model.module.names if hasattr(model, 'module') else model.names  # get class names
+            model.half() if fp16 else model.float()
+            self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
+        elif pt:  # PyTorch
             model = attempt_load_weights(weights if isinstance(weights, list) else w,
                                          device=device,
                                          inplace=True,
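The key `AutoBackend` change: an already-constructed `torch.nn.Module` can now be passed in place of a weights path, and the `nn_module` flag short-circuits the weight download so the module is used as-is. A minimal sketch of the detection logic only (the `Conv2d` stands in for a real YOLO module):

```python
import torch
import torch.nn as nn

weights = nn.Conv2d(3, 16, 3)  # stand-in for an in-memory YOLO model
nn_module = isinstance(weights, torch.nn.Module)
print(nn_module)  # True -> the module is moved to the device and used directly
```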
@@ -215,7 +223,7 @@ class AutoBackend(nn.Module):
         if self.nhwc:
             im = im.permute(0, 2, 3, 1)  # torch BCHW to numpy BHWC shape(1,320,192,3)
 
-        if self.pt:  # PyTorch
+        if self.pt or self.nn_module:  # PyTorch
             y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
         elif self.jit:  # TorchScript
             y = self.model(im)
@@ -294,7 +302,7 @@ class AutoBackend(nn.Module):
 
     def warmup(self, imgsz=(1, 3, 640, 640)):
         # Warmup model by running inference once
-        warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton
+        warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module
         if any(warmup_types) and (self.device.type != 'cpu' or self.triton):
             im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device)  # input
             for _ in range(2 if self.jit else 1):  #
@@ -306,7 +314,7 @@ class AutoBackend(nn.Module):
         # types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
         from ultralytics.yolo.engine.exporter import export_formats
         sf = list(export_formats().Suffix)  # export suffixes
-        if not is_url(p, check=False):
+        if not is_url(p, check=False) and not isinstance(p, str):
             check_suffix(p, sf)  # checks
         url = urlparse(p)  # if url may be Triton inference server
         types = [s in Path(p).name for s in sf]
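For context, `_model_type` decides the backend by testing export suffixes against the weight name. A sketch of the suffix matching with a hypothetical subset of suffixes (the real list comes from `export_formats().Suffix`):

```python
from pathlib import Path

sf = ['.pt', '.torchscript', '.onnx']  # hypothetical subset of export suffixes
p = "yolov5n.onnx"
types = [s in Path(p).name for s in sf]
print(types)  # [False, False, True] -> treated as an ONNX model
```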