mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Fix yolo mode=train
CLI bug on model load (#133)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
c3d961fb03
commit
8f3cd52844
59
tests/tests_cli.py
Normal file
59
tests/tests_cli.py
Normal file
@ -0,0 +1,59 @@
|
||||
import os
|
||||
|
||||
from ultralytics.yolo.utils import ROOT
|
||||
|
||||
|
||||
def test_checks():
|
||||
os.system('yolo mode=checks')
|
||||
|
||||
|
||||
# Train checks ---------------------------------------------------------------------------------------------------------
|
||||
def test_train_detect():
|
||||
os.system('yolo mode=train task=detect model=yolov8n.yaml data=coco128.yaml imgsz=32 epochs=1')
|
||||
|
||||
|
||||
def test_train_segment():
|
||||
os.system('yolo mode=train task=segment model=yolov8n-seg.yaml data=coco128-seg.yaml imgsz=32 epochs=1')
|
||||
|
||||
|
||||
def test_train_classify():
|
||||
pass
|
||||
|
||||
|
||||
# Val checks -----------------------------------------------------------------------------------------------------------
|
||||
def test_val_detect():
|
||||
os.system('yolo mode=val task=detect model=yolov8n.pt data=coco128.yaml imgsz=32 epochs=1')
|
||||
|
||||
|
||||
def test_val_segment():
|
||||
pass
|
||||
|
||||
|
||||
def test_val_classify():
|
||||
pass
|
||||
|
||||
|
||||
# Predict checks -------------------------------------------------------------------------------------------------------
|
||||
def test_predict_detect():
|
||||
os.system(f"yolo mode=predict model=yolov8n.pt source={ROOT / 'assets'}")
|
||||
|
||||
|
||||
def test_predict_segment():
|
||||
pass
|
||||
|
||||
|
||||
def test_predict_classify():
|
||||
pass
|
||||
|
||||
|
||||
# Export checks --------------------------------------------------------------------------------------------------------
|
||||
def test_export_detect_torchscript():
|
||||
os.system('yolo mode=export model=yolov8n.pt format=torchscript')
|
||||
|
||||
|
||||
def test_export_segment_torchscript():
|
||||
pass
|
||||
|
||||
|
||||
def test_export_classify_torchscript():
|
||||
pass
|
@ -5,11 +5,10 @@ import time
|
||||
import requests
|
||||
|
||||
from ultralytics.hub.config import HUB_API_ROOT
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, colorstr, emojis, yaml_load
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, LOGGER, RANK, SETTINGS, colorstr, emojis, yaml_load
|
||||
|
||||
PREFIX = colorstr('Ultralytics: ')
|
||||
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
|
||||
DEFAULT_CONFIG_DICT = yaml_load(DEFAULT_CONFIG)
|
||||
|
||||
|
||||
def check_dataset_disk_space(url='https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip', sf=2.0):
|
||||
|
@ -10,13 +10,11 @@ import torchvision
|
||||
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
|
||||
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
|
||||
GhostBottleneck, GhostConv, Segment)
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, yaml_load
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER, colorstr, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_yaml
|
||||
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
|
||||
model_info, scale_img, time_sync)
|
||||
|
||||
DEFAULT_CONFIG_DICT = yaml_load(DEFAULT_CONFIG, append_filename=False)
|
||||
|
||||
|
||||
class BaseModel(nn.Module):
|
||||
'''
|
||||
@ -286,16 +284,15 @@ class ClassificationModel(BaseModel):
|
||||
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
||||
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
||||
from ultralytics.yolo.utils.downloads import attempt_download
|
||||
default_keys = DEFAULT_CONFIG_DICT.keys()
|
||||
|
||||
model = Ensemble()
|
||||
for w in weights if isinstance(weights, list) else [weights]:
|
||||
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
|
||||
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']}
|
||||
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
||||
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||
|
||||
# Model compatibility updates
|
||||
ckpt.args = {k: v for k, v in args.items() if k in default_keys}
|
||||
ckpt.args = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS}
|
||||
|
||||
# Append
|
||||
model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode
|
||||
|
@ -362,7 +362,7 @@ class BaseTrainer:
|
||||
return
|
||||
# We should improve the code flow here. This function looks hacky
|
||||
model = self.model
|
||||
pretrained = not (str(model).endswith(".yaml"))
|
||||
pretrained = not str(model).endswith(".yaml")
|
||||
# config
|
||||
if not pretrained:
|
||||
model = check_file(model)
|
||||
|
@ -63,6 +63,11 @@ pd.options.display.max_columns = 10
|
||||
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
|
||||
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
|
||||
|
||||
# Default config dictionary
|
||||
with open(DEFAULT_CONFIG, errors='ignore') as f:
|
||||
DEFAULT_CONFIG_DICT = yaml.safe_load(f)
|
||||
DEFAULT_CONFIG_KEYS = DEFAULT_CONFIG_DICT.keys()
|
||||
|
||||
|
||||
def is_colab():
|
||||
"""
|
||||
|
@ -16,7 +16,7 @@ import torch.nn.functional as F
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
import ultralytics
|
||||
from ultralytics.yolo.utils import LOGGER
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER
|
||||
from ultralytics.yolo.utils.checks import git_describe
|
||||
|
||||
from .checks import check_version
|
||||
@ -270,6 +270,7 @@ class ModelEMA:
|
||||
def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
|
||||
# Strip optimizer from 'f' to finalize training, optionally save as 's'
|
||||
x = torch.load(f, map_location=torch.device('cpu'))
|
||||
args = {**DEFAULT_CONFIG_DICT, **x['train_args']} # combine model args with default args, preferring model args
|
||||
if x.get('ema'):
|
||||
x['model'] = x['ema'] # replace model with ema
|
||||
for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
|
||||
@ -278,6 +279,7 @@ def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_op
|
||||
x['model'].half() # to FP16
|
||||
for p in x['model'].parameters():
|
||||
p.requires_grad = False
|
||||
x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} # strip non-default keys
|
||||
torch.save(x, s or f)
|
||||
mb = os.path.getsize(s or f) / 1E6 # filesize
|
||||
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
||||
|
@ -54,9 +54,12 @@ class DetectionTrainer(BaseTrainer):
|
||||
self.model.names = self.data["names"]
|
||||
|
||||
def load_model(self, model_cfg=None, weights=None, verbose=True):
|
||||
model = DetectionModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
|
||||
model = DetectionModel(model_cfg or getattr(weights, 'yaml', None) or weights['model'].yaml,
|
||||
ch=3,
|
||||
nc=self.data["nc"],
|
||||
verbose=verbose)
|
||||
if weights:
|
||||
model.load(weights, verbose)
|
||||
model.load(weights['model'] if isinstance(weights, dict) else weights, verbose)
|
||||
return model
|
||||
|
||||
def get_validator(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user