mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-24 05:55:51 +08:00
Fix yolo mode=train
CLI bug on model load (#133)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
c3d961fb03
commit
8f3cd52844
59
tests/tests_cli.py
Normal file
59
tests/tests_cli.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from ultralytics.yolo.utils import ROOT
|
||||||
|
|
||||||
|
|
||||||
|
def test_checks():
|
||||||
|
os.system('yolo mode=checks')
|
||||||
|
|
||||||
|
|
||||||
|
# Train checks ---------------------------------------------------------------------------------------------------------
|
||||||
|
def test_train_detect():
|
||||||
|
os.system('yolo mode=train task=detect model=yolov8n.yaml data=coco128.yaml imgsz=32 epochs=1')
|
||||||
|
|
||||||
|
|
||||||
|
def test_train_segment():
|
||||||
|
os.system('yolo mode=train task=segment model=yolov8n-seg.yaml data=coco128-seg.yaml imgsz=32 epochs=1')
|
||||||
|
|
||||||
|
|
||||||
|
def test_train_classify():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Val checks -----------------------------------------------------------------------------------------------------------
|
||||||
|
def test_val_detect():
|
||||||
|
os.system('yolo mode=val task=detect model=yolov8n.pt data=coco128.yaml imgsz=32 epochs=1')
|
||||||
|
|
||||||
|
|
||||||
|
def test_val_segment():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_val_classify():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Predict checks -------------------------------------------------------------------------------------------------------
|
||||||
|
def test_predict_detect():
|
||||||
|
os.system(f"yolo mode=predict model=yolov8n.pt source={ROOT / 'assets'}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_predict_segment():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_predict_classify():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Export checks --------------------------------------------------------------------------------------------------------
|
||||||
|
def test_export_detect_torchscript():
|
||||||
|
os.system('yolo mode=export model=yolov8n.pt format=torchscript')
|
||||||
|
|
||||||
|
|
||||||
|
def test_export_segment_torchscript():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_export_classify_torchscript():
|
||||||
|
pass
|
@ -5,11 +5,10 @@ import time
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
from ultralytics.hub.config import HUB_API_ROOT
|
from ultralytics.hub.config import HUB_API_ROOT
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, colorstr, emojis, yaml_load
|
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, LOGGER, RANK, SETTINGS, colorstr, emojis, yaml_load
|
||||||
|
|
||||||
PREFIX = colorstr('Ultralytics: ')
|
PREFIX = colorstr('Ultralytics: ')
|
||||||
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
|
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
|
||||||
DEFAULT_CONFIG_DICT = yaml_load(DEFAULT_CONFIG)
|
|
||||||
|
|
||||||
|
|
||||||
def check_dataset_disk_space(url='https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip', sf=2.0):
|
def check_dataset_disk_space(url='https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip', sf=2.0):
|
||||||
|
@ -10,13 +10,11 @@ import torchvision
|
|||||||
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
|
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
|
||||||
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
|
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
|
||||||
GhostBottleneck, GhostConv, Segment)
|
GhostBottleneck, GhostConv, Segment)
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, yaml_load
|
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER, colorstr, yaml_load
|
||||||
from ultralytics.yolo.utils.checks import check_yaml
|
from ultralytics.yolo.utils.checks import check_yaml
|
||||||
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
|
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
|
||||||
model_info, scale_img, time_sync)
|
model_info, scale_img, time_sync)
|
||||||
|
|
||||||
DEFAULT_CONFIG_DICT = yaml_load(DEFAULT_CONFIG, append_filename=False)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseModel(nn.Module):
|
class BaseModel(nn.Module):
|
||||||
'''
|
'''
|
||||||
@ -286,16 +284,15 @@ class ClassificationModel(BaseModel):
|
|||||||
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
||||||
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
||||||
from ultralytics.yolo.utils.downloads import attempt_download
|
from ultralytics.yolo.utils.downloads import attempt_download
|
||||||
default_keys = DEFAULT_CONFIG_DICT.keys()
|
|
||||||
|
|
||||||
model = Ensemble()
|
model = Ensemble()
|
||||||
for w in weights if isinstance(weights, list) else [weights]:
|
for w in weights if isinstance(weights, list) else [weights]:
|
||||||
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
|
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
|
||||||
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']}
|
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
||||||
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||||
|
|
||||||
# Model compatibility updates
|
# Model compatibility updates
|
||||||
ckpt.args = {k: v for k, v in args.items() if k in default_keys}
|
ckpt.args = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS}
|
||||||
|
|
||||||
# Append
|
# Append
|
||||||
model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode
|
model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode
|
||||||
|
@ -362,7 +362,7 @@ class BaseTrainer:
|
|||||||
return
|
return
|
||||||
# We should improve the code flow here. This function looks hacky
|
# We should improve the code flow here. This function looks hacky
|
||||||
model = self.model
|
model = self.model
|
||||||
pretrained = not (str(model).endswith(".yaml"))
|
pretrained = not str(model).endswith(".yaml")
|
||||||
# config
|
# config
|
||||||
if not pretrained:
|
if not pretrained:
|
||||||
model = check_file(model)
|
model = check_file(model)
|
||||||
|
@ -63,6 +63,11 @@ pd.options.display.max_columns = 10
|
|||||||
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
|
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
|
||||||
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
|
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
|
||||||
|
|
||||||
|
# Default config dictionary
|
||||||
|
with open(DEFAULT_CONFIG, errors='ignore') as f:
|
||||||
|
DEFAULT_CONFIG_DICT = yaml.safe_load(f)
|
||||||
|
DEFAULT_CONFIG_KEYS = DEFAULT_CONFIG_DICT.keys()
|
||||||
|
|
||||||
|
|
||||||
def is_colab():
|
def is_colab():
|
||||||
"""
|
"""
|
||||||
|
@ -16,7 +16,7 @@ import torch.nn.functional as F
|
|||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
import ultralytics
|
import ultralytics
|
||||||
from ultralytics.yolo.utils import LOGGER
|
from ultralytics.yolo.utils import DEFAULT_CONFIG_DICT, DEFAULT_CONFIG_KEYS, LOGGER
|
||||||
from ultralytics.yolo.utils.checks import git_describe
|
from ultralytics.yolo.utils.checks import git_describe
|
||||||
|
|
||||||
from .checks import check_version
|
from .checks import check_version
|
||||||
@ -270,6 +270,7 @@ class ModelEMA:
|
|||||||
def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
|
def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
|
||||||
# Strip optimizer from 'f' to finalize training, optionally save as 's'
|
# Strip optimizer from 'f' to finalize training, optionally save as 's'
|
||||||
x = torch.load(f, map_location=torch.device('cpu'))
|
x = torch.load(f, map_location=torch.device('cpu'))
|
||||||
|
args = {**DEFAULT_CONFIG_DICT, **x['train_args']} # combine model args with default args, preferring model args
|
||||||
if x.get('ema'):
|
if x.get('ema'):
|
||||||
x['model'] = x['ema'] # replace model with ema
|
x['model'] = x['ema'] # replace model with ema
|
||||||
for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
|
for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
|
||||||
@ -278,6 +279,7 @@ def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_op
|
|||||||
x['model'].half() # to FP16
|
x['model'].half() # to FP16
|
||||||
for p in x['model'].parameters():
|
for p in x['model'].parameters():
|
||||||
p.requires_grad = False
|
p.requires_grad = False
|
||||||
|
x['train_args'] = {k: v for k, v in args.items() if k in DEFAULT_CONFIG_KEYS} # strip non-default keys
|
||||||
torch.save(x, s or f)
|
torch.save(x, s or f)
|
||||||
mb = os.path.getsize(s or f) / 1E6 # filesize
|
mb = os.path.getsize(s or f) / 1E6 # filesize
|
||||||
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
||||||
|
@ -54,9 +54,12 @@ class DetectionTrainer(BaseTrainer):
|
|||||||
self.model.names = self.data["names"]
|
self.model.names = self.data["names"]
|
||||||
|
|
||||||
def load_model(self, model_cfg=None, weights=None, verbose=True):
|
def load_model(self, model_cfg=None, weights=None, verbose=True):
|
||||||
model = DetectionModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
|
model = DetectionModel(model_cfg or getattr(weights, 'yaml', None) or weights['model'].yaml,
|
||||||
|
ch=3,
|
||||||
|
nc=self.data["nc"],
|
||||||
|
verbose=verbose)
|
||||||
if weights:
|
if weights:
|
||||||
model.load(weights, verbose)
|
model.load(weights['model'] if isinstance(weights, dict) else weights, verbose)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def get_validator(self):
|
def get_validator(self):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user