mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-07-07 22:04:53 +08:00
Fix load and resume and update autodownload endpoint (#136)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
6d5123297e
commit
82c849c163
@ -82,7 +82,7 @@ class YOLO:
|
|||||||
self.ckpt_path = weights
|
self.ckpt_path = weights
|
||||||
self.task = self.model.args["task"]
|
self.task = self.model.args["task"]
|
||||||
self.overrides = self.model.args
|
self.overrides = self.model.args
|
||||||
self.overrides["device"] = '' # reset device
|
self._reset_ckpt_args(self.overrides)
|
||||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
|
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
|
||||||
self._guess_ops_from_task(self.task)
|
self._guess_ops_from_task(self.task)
|
||||||
|
|
||||||
@ -199,27 +199,6 @@ class YOLO:
|
|||||||
|
|
||||||
self.trainer.train()
|
self.trainer.train()
|
||||||
|
|
||||||
def resume(self, task=None, model=None):
|
|
||||||
"""
|
|
||||||
Resume a training task. Requires either `task` or `model`. `model` takes the higher precedence.
|
|
||||||
Args:
|
|
||||||
task (str): The task type you want to resume. Automatically finds the last run to resume if `model` is not specified.
|
|
||||||
model (str): The model checkpoint to resume from. If not found, the last run of the given task type is resumed.
|
|
||||||
If `model` is specified
|
|
||||||
"""
|
|
||||||
if task:
|
|
||||||
if task.lower() not in MODEL_MAP:
|
|
||||||
raise SyntaxError(f"unrecognised task - {task}. Supported tasks are {MODEL_MAP.keys()}")
|
|
||||||
else:
|
|
||||||
ckpt = torch.load(model, map_location="cpu")
|
|
||||||
task = ckpt["train_args"]["task"]
|
|
||||||
del ckpt
|
|
||||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._guess_ops_from_task(
|
|
||||||
task=task.lower())
|
|
||||||
self.trainer = self.TrainerClass(overrides={"task": task.lower(), "resume": model or True})
|
|
||||||
|
|
||||||
self.trainer.train()
|
|
||||||
|
|
||||||
def to(self, device):
|
def to(self, device):
|
||||||
self.model.to(device)
|
self.model.to(device)
|
||||||
|
|
||||||
@ -240,3 +219,10 @@ class YOLO:
|
|||||||
|
|
||||||
def forward(self, imgs):
|
def forward(self, imgs):
|
||||||
return self.__call__(imgs)
|
return self.__call__(imgs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _reset_ckpt_args(args):
|
||||||
|
args.pop("device", None)
|
||||||
|
args.pop("project", None)
|
||||||
|
args.pop("name", None)
|
||||||
|
args.pop("batch_size", None)
|
||||||
|
@ -367,7 +367,7 @@ class BaseTrainer:
|
|||||||
if not pretrained:
|
if not pretrained:
|
||||||
model = check_file(model)
|
model = check_file(model)
|
||||||
ckpt = self.load_ckpt(model) if pretrained else None
|
ckpt = self.load_ckpt(model) if pretrained else None
|
||||||
self.model = self.load_model(model_cfg=None if pretrained else model, weights=ckpt) # model
|
self.model = self.load_model(model_cfg=None if pretrained else model, weights=ckpt["model"]) # model
|
||||||
return ckpt
|
return ckpt
|
||||||
|
|
||||||
def load_ckpt(self, ckpt):
|
def load_ckpt(self, ckpt):
|
||||||
|
@ -45,11 +45,12 @@ def is_url(url, check=True):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'):
|
def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'):
|
||||||
# Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
|
# Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
|
||||||
|
|
||||||
def github_assets(repository, version='latest'):
|
def github_assets(repository, version='latest'):
|
||||||
# Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov5m.pt', ...])
|
# Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov5m.pt', ...])
|
||||||
|
# Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov8s.pt', ...])
|
||||||
if version != 'latest':
|
if version != 'latest':
|
||||||
version = f'tags/{version}' # i.e. tags/v6.2
|
version = f'tags/{version}' # i.e. tags/v6.2
|
||||||
response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json() # github api
|
response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json() # github api
|
||||||
@ -70,6 +71,7 @@ def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'):
|
|||||||
|
|
||||||
# GitHub assets
|
# GitHub assets
|
||||||
assets = [f'yolov5{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] # default
|
assets = [f'yolov5{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] # default
|
||||||
|
assets = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] # default
|
||||||
try:
|
try:
|
||||||
tag, assets = github_assets(repo, release)
|
tag, assets = github_assets(repo, release)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -54,12 +54,9 @@ class DetectionTrainer(BaseTrainer):
|
|||||||
self.model.names = self.data["names"]
|
self.model.names = self.data["names"]
|
||||||
|
|
||||||
def load_model(self, model_cfg=None, weights=None, verbose=True):
|
def load_model(self, model_cfg=None, weights=None, verbose=True):
|
||||||
model = DetectionModel(model_cfg or getattr(weights, 'yaml', None) or weights['model'].yaml,
|
model = DetectionModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
|
||||||
ch=3,
|
|
||||||
nc=self.data["nc"],
|
|
||||||
verbose=verbose)
|
|
||||||
if weights:
|
if weights:
|
||||||
model.load(weights['model'] if isinstance(weights, dict) else weights, verbose)
|
model.load(weights, verbose)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def get_validator(self):
|
def get_validator(self):
|
||||||
|
@ -17,12 +17,9 @@ from ultralytics.yolo.utils.torch_utils import de_parallel
|
|||||||
class SegmentationTrainer(v8.detect.DetectionTrainer):
|
class SegmentationTrainer(v8.detect.DetectionTrainer):
|
||||||
|
|
||||||
def load_model(self, model_cfg=None, weights=None, verbose=True):
|
def load_model(self, model_cfg=None, weights=None, verbose=True):
|
||||||
model = SegmentationModel(model_cfg or getattr(weights, 'yaml', None) or weights['model'].yaml,
|
model = SegmentationModel(model_cfg or weights.yaml, ch=3, nc=self.data["nc"], verbose=verbose)
|
||||||
ch=3,
|
|
||||||
nc=self.data["nc"],
|
|
||||||
verbose=verbose)
|
|
||||||
if weights:
|
if weights:
|
||||||
model.load(weights['model'] if isinstance(weights, dict) else weights, verbose)
|
model.load(weights, verbose)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def get_validator(self):
|
def get_validator(self):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user