mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Unified model loading with backwards compatibility (#132)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
8996c5c6cf
commit
c3d961fb03
@@ -42,9 +42,9 @@ Ultralytics YOLO comes with pythonic Model and Trainer interface.
|
|||||||
import ultralytics
|
import ultralytics
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
|
|
||||||
model = YOLO("s-seg.yaml") # automatically detects task type
|
model = YOLO("yolov8n-seg.yaml") # automatically detects task type
|
||||||
model = YOLO("s-seg.pt") # load checkpoint
|
model = YOLO("yolov8n.pt") # load checkpoint
|
||||||
model.train(data="coco128-segments", epochs=1, lr0=0.01, ...)
|
model.train(data="coco128-seg.yaml", epochs=1, lr0=0.01, ...)
|
||||||
model.train(data="coco128-segments", epochs=1, lr0=0.01, device="0,1,2,3") # DDP mode
|
model.train(data="coco128-seg.yaml", epochs=1, lr0=0.01, device="0,1,2,3") # DDP mode
|
||||||
```
|
```
|
||||||
[API Guide](sdk.md){ .md-button .md-button--primary}
|
[API Guide](sdk.md){ .md-button .md-button--primary}
|
||||||
|
@@ -1,11 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
|
from ultralytics.yolo.utils import ROOT
|
||||||
|
|
||||||
def test_model_init():
|
|
||||||
model = YOLO("yolov8n.yaml")
|
|
||||||
model.info()
|
|
||||||
|
|
||||||
|
|
||||||
def test_model_forward():
|
def test_model_forward():
|
||||||
@@ -29,9 +25,9 @@ def test_model_fuse():
|
|||||||
model.fuse()
|
model.fuse()
|
||||||
|
|
||||||
|
|
||||||
def test_visualize_preds():
|
def test_predict_dir():
|
||||||
model = YOLO("yolov8n.pt")
|
model = YOLO("yolov8n.pt")
|
||||||
model.predict(source="ultralytics/assets")
|
model.predict(source=ROOT / "assets")
|
||||||
|
|
||||||
|
|
||||||
def test_val():
|
def test_val():
|
||||||
@@ -39,7 +35,7 @@ def test_val():
|
|||||||
model.val(data="coco128.yaml", imgsz=32)
|
model.val(data="coco128.yaml", imgsz=32)
|
||||||
|
|
||||||
|
|
||||||
def test_model_resume():
|
def test_train_resume():
|
||||||
model = YOLO("yolov8n.yaml")
|
model = YOLO("yolov8n.yaml")
|
||||||
model.train(epochs=1, imgsz=32, data="coco128.yaml")
|
model.train(epochs=1, imgsz=32, data="coco128.yaml")
|
||||||
try:
|
try:
|
||||||
@@ -48,16 +44,21 @@ def test_model_resume():
|
|||||||
print("Successfully caught resume assert!")
|
print("Successfully caught resume assert!")
|
||||||
|
|
||||||
|
|
||||||
def test_model_train_pretrained():
|
def test_train_scratch():
|
||||||
model = YOLO("yolov8n.pt")
|
|
||||||
model.train(data="coco128.yaml", epochs=1, imgsz=32)
|
|
||||||
model = YOLO("yolov8n.yaml")
|
model = YOLO("yolov8n.yaml")
|
||||||
model.train(data="coco128.yaml", epochs=1, imgsz=32)
|
model.train(data="coco128.yaml", epochs=1, imgsz=32)
|
||||||
img = torch.rand(1, 3, 320, 320)
|
img = torch.rand(1, 3, 320, 320)
|
||||||
model(img)
|
model(img)
|
||||||
|
|
||||||
|
|
||||||
def test_exports():
|
def test_train_pretrained():
|
||||||
|
model = YOLO("yolov8n.pt")
|
||||||
|
model.train(data="coco128.yaml", epochs=1, imgsz=32)
|
||||||
|
img = torch.rand(1, 3, 320, 320)
|
||||||
|
model(img)
|
||||||
|
|
||||||
|
|
||||||
|
def test_export_torchscript():
|
||||||
"""
|
"""
|
||||||
Format Argument Suffix CPU GPU
|
Format Argument Suffix CPU GPU
|
||||||
0 PyTorch - .pt True True
|
0 PyTorch - .pt True True
|
||||||
@@ -74,26 +75,35 @@ def test_exports():
|
|||||||
11 PaddlePaddle paddle _paddle_model True True
|
11 PaddlePaddle paddle _paddle_model True True
|
||||||
"""
|
"""
|
||||||
from ultralytics.yolo.engine.exporter import export_formats
|
from ultralytics.yolo.engine.exporter import export_formats
|
||||||
|
|
||||||
print(export_formats())
|
print(export_formats())
|
||||||
|
|
||||||
model = YOLO("yolov8n.yaml")
|
model = YOLO("yolov8n.yaml")
|
||||||
model.export(format='torchscript')
|
model.export(format='torchscript')
|
||||||
|
|
||||||
|
|
||||||
|
def test_export_onnx():
|
||||||
|
model = YOLO("yolov8n.yaml")
|
||||||
model.export(format='onnx')
|
model.export(format='onnx')
|
||||||
|
|
||||||
|
|
||||||
|
def test_export_openvino():
|
||||||
|
model = YOLO("yolov8n.yaml")
|
||||||
model.export(format='openvino')
|
model.export(format='openvino')
|
||||||
|
|
||||||
|
|
||||||
|
def test_export_coreml():
|
||||||
|
model = YOLO("yolov8n.yaml")
|
||||||
model.export(format='coreml')
|
model.export(format='coreml')
|
||||||
|
|
||||||
|
|
||||||
|
def test_export_paddle():
|
||||||
|
model = YOLO("yolov8n.yaml")
|
||||||
model.export(format='paddle')
|
model.export(format='paddle')
|
||||||
|
|
||||||
|
|
||||||
def test():
|
# def run_all_tests(): # do not name function test_...
|
||||||
test_model_forward()
|
# pass
|
||||||
test_model_info()
|
#
|
||||||
test_model_fuse()
|
#
|
||||||
test_visualize_preds()
|
# if __name__ == "__main__":
|
||||||
test_val()
|
# run_all_tests()
|
||||||
test_model_resume()
|
|
||||||
test_model_train_pretrained()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test()
|
|
||||||
|
@@ -124,7 +124,7 @@ def smart_request(*args, retry=3, timeout=30, thread=True, code=-1, method="post
|
|||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def sync_analytics(cfg, all_keys=False, enabled=True):
|
def sync_analytics(cfg, all_keys=False, enabled=False):
|
||||||
"""
|
"""
|
||||||
Sync analytics data if enabled in the global settings
|
Sync analytics data if enabled in the global settings
|
||||||
|
|
||||||
|
@@ -10,11 +10,13 @@ import torchvision
|
|||||||
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
|
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
|
||||||
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
|
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
|
||||||
GhostBottleneck, GhostConv, Segment)
|
GhostBottleneck, GhostConv, Segment)
|
||||||
from ultralytics.yolo.utils import LOGGER, colorstr, yaml_load
|
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, colorstr, yaml_load
|
||||||
from ultralytics.yolo.utils.checks import check_yaml
|
from ultralytics.yolo.utils.checks import check_yaml
|
||||||
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
|
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
|
||||||
model_info, scale_img, time_sync)
|
model_info, scale_img, time_sync)
|
||||||
|
|
||||||
|
DEFAULT_CONFIG_DICT = yaml_load(DEFAULT_CONFIG, append_filename=False)
|
||||||
|
|
||||||
|
|
||||||
class BaseModel(nn.Module):
|
class BaseModel(nn.Module):
|
||||||
'''
|
'''
|
||||||
@@ -211,7 +213,7 @@ class DetectionModel(BaseModel):
|
|||||||
return y
|
return y
|
||||||
|
|
||||||
def load(self, weights, verbose=True):
|
def load(self, weights, verbose=True):
|
||||||
csd = weights['model'].float().state_dict() # checkpoint state_dict as FP32
|
csd = weights.float().state_dict() # checkpoint state_dict as FP32
|
||||||
csd = intersect_dicts(csd, self.state_dict()) # intersect
|
csd = intersect_dicts(csd, self.state_dict()) # intersect
|
||||||
self.load_state_dict(csd, strict=False) # load
|
self.load_state_dict(csd, strict=False) # load
|
||||||
if verbose:
|
if verbose:
|
||||||
@@ -281,21 +283,21 @@ class ClassificationModel(BaseModel):
|
|||||||
# Functions ------------------------------------------------------------------------------------------------------------
|
# Functions ------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def attempt_load_weights(weights, device=None, inplace=True, fuse=True):
|
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
||||||
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
||||||
from ultralytics.yolo.utils.downloads import attempt_download
|
from ultralytics.yolo.utils.downloads import attempt_download
|
||||||
|
default_keys = DEFAULT_CONFIG_DICT.keys()
|
||||||
|
|
||||||
model = Ensemble()
|
model = Ensemble()
|
||||||
for w in weights if isinstance(weights, list) else [weights]:
|
for w in weights if isinstance(weights, list) else [weights]:
|
||||||
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
|
ckpt = torch.load(attempt_download(w), map_location='cpu') # load
|
||||||
|
args = {**DEFAULT_CONFIG_DICT, **ckpt['train_args']}
|
||||||
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||||
|
|
||||||
# Model compatibility updates
|
# Model compatibility updates
|
||||||
if not hasattr(ckpt, 'stride'):
|
ckpt.args = {k: v for k, v in args.items() if k in default_keys}
|
||||||
ckpt.stride = torch.tensor([32.])
|
|
||||||
if hasattr(ckpt, 'names') and isinstance(ckpt.names, (list, tuple)):
|
|
||||||
ckpt.names = dict(enumerate(ckpt.names)) # convert to dict
|
|
||||||
|
|
||||||
|
# Append
|
||||||
model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode
|
model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode
|
||||||
|
|
||||||
# Module compatibility updates
|
# Module compatibility updates
|
||||||
@@ -310,7 +312,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=True):
|
|||||||
if len(model) == 1:
|
if len(model) == 1:
|
||||||
return model[-1]
|
return model[-1]
|
||||||
|
|
||||||
# Return detection ensemble
|
# Return ensemble
|
||||||
print(f'Ensemble created with {weights}\n')
|
print(f'Ensemble created with {weights}\n')
|
||||||
for k in 'names', 'nc', 'yaml':
|
for k in 'names', 'nc', 'yaml':
|
||||||
setattr(model, k, getattr(model[0], k))
|
setattr(model, k, getattr(model[0], k))
|
||||||
|
@@ -164,8 +164,8 @@ class Exporter:
|
|||||||
assert not self.args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic'
|
assert not self.args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic'
|
||||||
|
|
||||||
# Checks
|
# Checks
|
||||||
if self.args.batch_size == 16:
|
# if self.args.batch_size == model.args['batch_size']: # user has not modified training batch_size
|
||||||
self.args.batch_size = 1 # TODO: resolve batch_size 16 default in config.yaml
|
self.args.batch_size = 1
|
||||||
self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size
|
self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size
|
||||||
if self.args.optimize:
|
if self.args.optimize:
|
||||||
assert self.device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'
|
assert self.device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'
|
||||||
@@ -778,7 +778,7 @@ def export(cfg):
|
|||||||
if Path(cfg.model).suffix == '.yaml':
|
if Path(cfg.model).suffix == '.yaml':
|
||||||
model = DetectionModel(cfg.model)
|
model = DetectionModel(cfg.model)
|
||||||
elif Path(cfg.model).suffix == '.pt':
|
elif Path(cfg.model).suffix == '.pt':
|
||||||
model = attempt_load_weights(cfg.model)
|
model = attempt_load_weights(cfg.model, fuse=True)
|
||||||
else:
|
else:
|
||||||
TypeError(f'Unsupported model type {cfg.model}')
|
TypeError(f'Unsupported model type {cfg.model}')
|
||||||
exporter(model=model)
|
exporter(model=model)
|
||||||
|
@@ -77,13 +77,12 @@ class YOLO:
|
|||||||
Args:
|
Args:
|
||||||
weights (str): model checkpoint to be loaded
|
weights (str): model checkpoint to be loaded
|
||||||
"""
|
"""
|
||||||
self.ckpt = torch.load(weights, map_location="cpu")
|
self.model = attempt_load_weights(weights)
|
||||||
self.task = self.ckpt["train_args"]["task"]
|
self.task = self.model.args["task"]
|
||||||
self.overrides = dict(self.ckpt["train_args"])
|
self.overrides = self.model.args
|
||||||
self.overrides["device"] = '' # reset device
|
self.overrides["device"] = '' # reset device
|
||||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
|
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
|
||||||
self._guess_ops_from_task(self.task)
|
self._guess_ops_from_task(self.task)
|
||||||
self.model = attempt_load_weights(weights, fuse=False)
|
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""
|
"""
|
||||||
@@ -189,7 +188,7 @@ class YOLO:
|
|||||||
raise AttributeError("dataset not provided! Please define `data` in config.yaml or pass as an argument.")
|
raise AttributeError("dataset not provided! Please define `data` in config.yaml or pass as an argument.")
|
||||||
|
|
||||||
self.trainer = self.TrainerClass(overrides=overrides)
|
self.trainer = self.TrainerClass(overrides=overrides)
|
||||||
self.trainer.model = self.trainer.load_model(weights=self.ckpt,
|
self.trainer.model = self.trainer.load_model(weights=self.model,
|
||||||
model_cfg=self.model.yaml if self.task != "classify" else None)
|
model_cfg=self.model.yaml if self.task != "classify" else None)
|
||||||
self.model = self.trainer.model # override here to save memory
|
self.model = self.trainer.model # override here to save memory
|
||||||
|
|
||||||
|
@@ -106,6 +106,9 @@ class BaseValidator:
|
|||||||
data = check_dataset_yaml(self.args.data)
|
data = check_dataset_yaml(self.args.data)
|
||||||
else:
|
else:
|
||||||
data = check_dataset(self.args.data)
|
data = check_dataset(self.args.data)
|
||||||
|
|
||||||
|
if self.device.type == 'cpu':
|
||||||
|
self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
|
||||||
self.dataloader = self.dataloader or \
|
self.dataloader = self.dataloader or \
|
||||||
self.get_dataloader(data.get("val") or data.set("test"), self.args.batch_size)
|
self.get_dataloader(data.get("val") or data.set("test"), self.args.batch_size)
|
||||||
|
|
||||||
|
@@ -271,19 +271,20 @@ def yaml_save(file='data.yaml', data=None):
|
|||||||
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
|
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
|
||||||
|
|
||||||
|
|
||||||
def yaml_load(file='data.yaml'):
|
def yaml_load(file='data.yaml', append_filename=True):
|
||||||
"""
|
"""
|
||||||
Load YAML data from a file.
|
Load YAML data from a file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file (str, optional): File name. Default is 'data.yaml'.
|
file (str, optional): File name. Default is 'data.yaml'.
|
||||||
|
append_filename (bool): Add the YAML filename to the YAML dictionary. Default is True.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: YAML data and file name.
|
dict: YAML data and file name.
|
||||||
"""
|
"""
|
||||||
with open(file, errors='ignore') as f:
|
with open(file, errors='ignore') as f:
|
||||||
# Add YAML filename to dict and return
|
# Add YAML filename to dict and return
|
||||||
return {**yaml.safe_load(f), 'yaml_file': str(file)}
|
return {**yaml.safe_load(f), 'yaml_file': str(file)} if append_filename else yaml.safe_load(f)
|
||||||
|
|
||||||
|
|
||||||
def get_settings(file=USER_CONFIG_DIR / 'settings.yaml'):
|
def get_settings(file=USER_CONFIG_DIR / 'settings.yaml'):
|
||||||
|
@@ -54,7 +54,7 @@ class DetectionTrainer(BaseTrainer):
|
|||||||
self.model.names = self.data["names"]
|
self.model.names = self.data["names"]
|
||||||
|
|
||||||
def load_model(self, model_cfg=None, weights=None, verbose=True):
|
def load_model(self, model_cfg=None, weights=None, verbose=True):
|
||||||
model = DetectionModel(model_cfg or weights["model"].yaml, ch=3, nc=self.data["nc"], verbose=verbose)
|
model = DetectionModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
|
||||||
if weights:
|
if weights:
|
||||||
model.load(weights, verbose)
|
model.load(weights, verbose)
|
||||||
return model
|
return model
|
||||||
|
@@ -17,7 +17,7 @@ from ultralytics.yolo.utils.torch_utils import de_parallel
|
|||||||
class SegmentationTrainer(v8.detect.DetectionTrainer):
|
class SegmentationTrainer(v8.detect.DetectionTrainer):
|
||||||
|
|
||||||
def load_model(self, model_cfg=None, weights=None, verbose=True):
|
def load_model(self, model_cfg=None, weights=None, verbose=True):
|
||||||
model = SegmentationModel(model_cfg or weights["model"].yaml, ch=3, nc=self.data["nc"], verbose=verbose)
|
model = SegmentationModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
|
||||||
if weights:
|
if weights:
|
||||||
model.load(weights, verbose)
|
model.load(weights, verbose)
|
||||||
return model
|
return model
|
||||||
|
Loading…
x
Reference in New Issue
Block a user