mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-29 02:14:22 +08:00
Fix CLI detect and segment resume (#134)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
c5c86a3acd
commit
6d5123297e
@ -47,6 +47,7 @@ class YOLO:
|
|||||||
self.trainer = None # trainer object
|
self.trainer = None # trainer object
|
||||||
self.task = None # task type
|
self.task = None # task type
|
||||||
self.ckpt = None # if loaded from *.pt
|
self.ckpt = None # if loaded from *.pt
|
||||||
|
self.ckpt_path = None
|
||||||
self.cfg = None # if loaded from *.yaml
|
self.cfg = None # if loaded from *.yaml
|
||||||
self.overrides = {} # overrides for trainer object
|
self.overrides = {} # overrides for trainer object
|
||||||
self.init_disabled = False # disable model initialization
|
self.init_disabled = False # disable model initialization
|
||||||
@ -78,6 +79,7 @@ class YOLO:
|
|||||||
weights (str): model checkpoint to be loaded
|
weights (str): model checkpoint to be loaded
|
||||||
"""
|
"""
|
||||||
self.model = attempt_load_weights(weights)
|
self.model = attempt_load_weights(weights)
|
||||||
|
self.ckpt_path = weights
|
||||||
self.task = self.model.args["task"]
|
self.task = self.model.args["task"]
|
||||||
self.overrides = self.model.args
|
self.overrides = self.model.args
|
||||||
self.overrides["device"] = '' # reset device
|
self.overrides["device"] = '' # reset device
|
||||||
@ -177,8 +179,8 @@ class YOLO:
|
|||||||
"""
|
"""
|
||||||
if not self.model:
|
if not self.model:
|
||||||
raise AttributeError("model not initialized. Use .new() or .load()")
|
raise AttributeError("model not initialized. Use .new() or .load()")
|
||||||
|
overrides = self.overrides.copy()
|
||||||
overrides = kwargs
|
overrides.update(kwargs)
|
||||||
if kwargs.get("cfg"):
|
if kwargs.get("cfg"):
|
||||||
LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
|
LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
|
||||||
overrides = yaml_load(check_yaml(kwargs["cfg"]))
|
overrides = yaml_load(check_yaml(kwargs["cfg"]))
|
||||||
@ -187,10 +189,13 @@ class YOLO:
|
|||||||
if not overrides.get("data"):
|
if not overrides.get("data"):
|
||||||
raise AttributeError("dataset not provided! Please define `data` in config.yaml or pass as an argument.")
|
raise AttributeError("dataset not provided! Please define `data` in config.yaml or pass as an argument.")
|
||||||
|
|
||||||
|
if overrides.get("resume"):
|
||||||
|
overrides["resume"] = self.ckpt_path
|
||||||
self.trainer = self.TrainerClass(overrides=overrides)
|
self.trainer = self.TrainerClass(overrides=overrides)
|
||||||
self.trainer.model = self.trainer.load_model(weights=self.model,
|
if not overrides.get("resume"):
|
||||||
model_cfg=self.model.yaml if self.task != "classify" else None)
|
self.trainer.model = self.trainer.load_model(weights=self.model,
|
||||||
self.model = self.trainer.model # override here to save memory
|
model_cfg=self.model.yaml if self.task != "classify" else None)
|
||||||
|
self.model = self.trainer.model # override here to save memory
|
||||||
|
|
||||||
self.trainer.train()
|
self.trainer.train()
|
||||||
|
|
||||||
|
@ -17,9 +17,12 @@ from ultralytics.yolo.utils.torch_utils import de_parallel
|
|||||||
class SegmentationTrainer(v8.detect.DetectionTrainer):
|
class SegmentationTrainer(v8.detect.DetectionTrainer):
|
||||||
|
|
||||||
def load_model(self, model_cfg=None, weights=None, verbose=True):
|
def load_model(self, model_cfg=None, weights=None, verbose=True):
|
||||||
model = SegmentationModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
|
model = SegmentationModel(model_cfg or getattr(weights, 'yaml', None) or weights['model'].yaml,
|
||||||
|
ch=3,
|
||||||
|
nc=self.data["nc"],
|
||||||
|
verbose=verbose)
|
||||||
if weights:
|
if weights:
|
||||||
model.load(weights, verbose)
|
model.load(weights['model'] if isinstance(weights, dict) else weights, verbose)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def get_validator(self):
|
def get_validator(self):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user