mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Classify training cleanup (#33)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
2e9b18ce4e
commit
6fe8bead35
@ -24,13 +24,12 @@ from ultralytics.yolo.utils import LOGGER, ROOT
|
|||||||
from ultralytics.yolo.utils.files import increment_path, save_yaml
|
from ultralytics.yolo.utils.files import increment_path, save_yaml
|
||||||
from ultralytics.yolo.utils.modeling import get_model
|
from ultralytics.yolo.utils.modeling import get_model
|
||||||
|
|
||||||
CONFIG_PATH_ABS = ROOT / "yolo/utils/configs"
|
DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yml"
|
||||||
DEFAULT_CONFIG = "defaults.yaml"
|
|
||||||
|
|
||||||
|
|
||||||
class BaseTrainer:
|
class BaseTrainer:
|
||||||
|
|
||||||
def __init__(self, config=CONFIG_PATH_ABS / DEFAULT_CONFIG, overrides={}):
|
def __init__(self, config=DEFAULT_CONFIG, overrides={}):
|
||||||
self.console = LOGGER
|
self.console = LOGGER
|
||||||
self.args = self._get_config(config, overrides)
|
self.args = self._get_config(config, overrides)
|
||||||
self.validator = None
|
self.validator = None
|
||||||
|
@ -1,25 +1,27 @@
|
|||||||
model: null
|
# YOLO 🚀 by Ultralytics, GPL-3.0 license
|
||||||
data: null
|
# Default training settings and hyperparameters for medium-augmentation COCO training
|
||||||
|
|
||||||
# Training options
|
|
||||||
|
# Train settings -------------------------------------------------------------------------------------------------------
|
||||||
|
model: null # i.e. yolov5s.pt
|
||||||
|
data: null # i.e. coco128.yaml
|
||||||
epochs: 300
|
epochs: 300
|
||||||
batch_size: 16
|
batch_size: 16
|
||||||
img_size: 640
|
img_size: 640
|
||||||
nosave: False
|
nosave: False
|
||||||
cache: False # True/ram for ram, or disc
|
cache: False # True/ram, disk or False
|
||||||
device: '' # cuda device, i.e. 0 or 0,1,2,3 or cpu
|
device: '' # cuda device, i.e. 0 or 0,1,2,3 or cpu
|
||||||
workers: 8
|
workers: 8
|
||||||
project: "ultralytics-yolo"
|
project: 'runs'
|
||||||
name: "exp" # TODO: make this informative, maybe exp{#number}_{datetime} ?
|
name: 'exp'
|
||||||
exist_ok: False
|
exist_ok: False
|
||||||
pretrained: False
|
pretrained: False
|
||||||
optimizer: "Adam" # choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
|
optimizer: 'SGD' # choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
|
||||||
verbose: False
|
verbose: False
|
||||||
seed: 0
|
seed: 0
|
||||||
local_rank: -1
|
local_rank: -1
|
||||||
#-----------------------------------#
|
|
||||||
|
|
||||||
# Hyper-parameters
|
# Hyperparameters ------------------------------------------------------------------------------------------------------
|
||||||
lr0: 0.001 # initial learning rate (SGD=1E-2, Adam=1E-3)
|
lr0: 0.001 # initial learning rate (SGD=1E-2, Adam=1E-3)
|
||||||
lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
|
lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
|
||||||
momentum: 0.937 # SGD momentum/Adam beta1
|
momentum: 0.937 # SGD momentum/Adam beta1
|
||||||
@ -50,9 +52,8 @@ mosaic: 1.0 # image mosaic (probability)
|
|||||||
mixup: 0.0 # image mixup (probability)
|
mixup: 0.0 # image mixup (probability)
|
||||||
copy_paste: 0.0 # segment copy-paste (probability)
|
copy_paste: 0.0 # segment copy-paste (probability)
|
||||||
|
|
||||||
# Hydra configs -------------------------------------
|
# Hydra configs --------------------------------------------------------------------------------------------------------
|
||||||
# to disable hydra directory creation
|
|
||||||
hydra:
|
hydra:
|
||||||
output_subdir: null
|
output_subdir: null # disable hydra directory creation
|
||||||
run:
|
run:
|
||||||
dir: .
|
dir: .
|
@ -107,18 +107,17 @@ def parse_model(d, ch): # model_dict, input_channels(3)
|
|||||||
return nn.Sequential(*layers), sorted(save)
|
return nn.Sequential(*layers), sorted(save)
|
||||||
|
|
||||||
|
|
||||||
def get_model(model: str):
|
def get_model(model='s.pt', pretrained=True):
|
||||||
|
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
|
||||||
if model.endswith(".pt"):
|
if model.endswith(".pt"):
|
||||||
model = model.split(".")[0]
|
model = model.split(".")[0]
|
||||||
|
|
||||||
if Path(model + ".pt").is_file():
|
if Path(f"{model}.pt").is_file(): # local file
|
||||||
trained_model = torch.load(model + ".pt", map_location='cpu')
|
return torch.load(f"{model}.pt", map_location='cpu')
|
||||||
elif model in torchvision.models.__dict__: # try torch hub classifier models
|
elif model in torchvision.models.__dict__: # TorchVision models i.e. resnet50, efficientnet_b0
|
||||||
trained_model = torch.hub.load("pytorch/vision", model, pretrained=True)
|
return torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
|
||||||
else:
|
else: # Ultralytics assets
|
||||||
model_ckpt = attempt_download(model + ".pt") # try ultralytics assets
|
return torch.load(attempt_download(f"{model}.pt"), map_location='cpu')
|
||||||
trained_model = torch.load(model_ckpt, map_location='cpu')
|
|
||||||
return trained_model
|
|
||||||
|
|
||||||
|
|
||||||
def yaml_load(file='data.yaml'):
|
def yaml_load(file='data.yaml'):
|
||||||
|
@ -4,13 +4,13 @@ from pathlib import Path
|
|||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
|
||||||
|
|
||||||
from ultralytics.yolo import v8
|
from ultralytics.yolo import v8
|
||||||
from ultralytics.yolo.data import build_classification_dataloader
|
from ultralytics.yolo.data import build_classification_dataloader
|
||||||
from ultralytics.yolo.engine.trainer import CONFIG_PATH_ABS, DEFAULT_CONFIG, BaseTrainer
|
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
|
||||||
from ultralytics.yolo.utils.downloads import download
|
from ultralytics.yolo.utils.downloads import download
|
||||||
from ultralytics.yolo.utils.files import WorkingDirectory
|
from ultralytics.yolo.utils.files import WorkingDirectory
|
||||||
|
from ultralytics.yolo.utils.loggers import colorstr
|
||||||
from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, torch_distributed_zero_first
|
from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, torch_distributed_zero_first
|
||||||
|
|
||||||
|
|
||||||
@ -30,8 +30,7 @@ class ClassificationTrainer(BaseTrainer):
|
|||||||
else:
|
else:
|
||||||
url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
|
url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
|
||||||
download(url, dir=data_dir.parent)
|
download(url, dir=data_dir.parent)
|
||||||
# TODO: add colorstr
|
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
|
||||||
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {'bold', data_dir}\n"
|
|
||||||
self.console.info(s)
|
self.console.info(s)
|
||||||
train_set = data_dir / "train"
|
train_set = data_dir / "train"
|
||||||
test_set = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val' # data/test or data/val
|
test_set = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val' # data/test or data/val
|
||||||
@ -48,7 +47,7 @@ class ClassificationTrainer(BaseTrainer):
|
|||||||
return torch.nn.functional.cross_entropy(preds, targets)
|
return torch.nn.functional.cross_entropy(preds, targets)
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path=CONFIG_PATH_ABS, config_name=str(DEFAULT_CONFIG).split(".")[0])
|
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.stem)
|
||||||
def train(cfg):
|
def train(cfg):
|
||||||
cfg.model = cfg.model or "resnet18"
|
cfg.model = cfg.model or "resnet18"
|
||||||
cfg.data = cfg.data or "imagenette160" # or yolo.ClassificationDataset("mnist")
|
cfg.data = cfg.data or "imagenette160" # or yolo.ClassificationDataset("mnist")
|
||||||
@ -59,7 +58,7 @@ def train(cfg):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
"""
|
"""
|
||||||
CLI usage:
|
CLI usage:
|
||||||
python ../path/to/train.py args.epochs=10 args.project="name" hyps.lr0=0.1
|
python path/to/train.py epochs=10 project=PROJECT lr0=0.1
|
||||||
|
|
||||||
TODO:
|
TODO:
|
||||||
Direct cli support, i.e, yolov8 classify_train args.epochs 10
|
Direct cli support, i.e, yolov8 classify_train args.epochs 10
|
||||||
|
Loading…
x
Reference in New Issue
Block a user