mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
update model initialization design, supports custom data/num_classes (#44)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
1f3aad86c1
commit
832ea56eb4
2
.github/workflows/ci.yaml
vendored
2
.github/workflows/ci.yaml
vendored
@ -94,7 +94,7 @@ jobs:
|
|||||||
- name: Test segmentation
|
- name: Test segmentation
|
||||||
shell: bash # for Windows compatibility
|
shell: bash # for Windows compatibility
|
||||||
run: |
|
run: |
|
||||||
python ultralytics/yolo/v8/segment/train.py cfg=yolov5n-seg.yaml data=coco128-seg.yaml epochs=1 img_size=64
|
python ultralytics/yolo/v8/segment/train.py model=yolov5n-seg.yaml data=coco128-seg.yaml epochs=1 img_size=64
|
||||||
- name: Test classification
|
- name: Test classification
|
||||||
shell: bash # for Windows compatibility
|
shell: bash # for Windows compatibility
|
||||||
run: |
|
run: |
|
||||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -131,3 +131,4 @@ dmypy.json
|
|||||||
# datasets and projects
|
# datasets and projects
|
||||||
datasets/
|
datasets/
|
||||||
ultralytics-yolo/
|
ultralytics-yolo/
|
||||||
|
runs/
|
@ -63,10 +63,8 @@ class BaseTrainer:
|
|||||||
else:
|
else:
|
||||||
self.data = check_dataset(self.data)
|
self.data = check_dataset(self.data)
|
||||||
self.trainset, self.testset = self.get_dataset(self.data)
|
self.trainset, self.testset = self.get_dataset(self.data)
|
||||||
if self.args.cfg is not None:
|
if self.args.model:
|
||||||
self.model = self.load_cfg(check_file(self.args.cfg))
|
self.model = self.get_model(self.args.model, self.data)
|
||||||
if self.args.model is not None:
|
|
||||||
self.model = self.get_model(self.args.model, self.args.pretrained).to(self.device)
|
|
||||||
|
|
||||||
# epoch level metrics
|
# epoch level metrics
|
||||||
self.metrics = {} # handle metrics returned by validator
|
self.metrics = {} # handle metrics returned by validator
|
||||||
@ -261,20 +259,20 @@ class BaseTrainer:
|
|||||||
"""
|
"""
|
||||||
return data["train"], data["val"]
|
return data["train"], data["val"]
|
||||||
|
|
||||||
def get_model(self, model, pretrained):
|
def get_model(self, model: str, data: Dict):
|
||||||
"""
|
"""
|
||||||
load/create/download model for any task
|
load/create/download model for any task
|
||||||
"""
|
"""
|
||||||
model = get_model(model)
|
pretrained = False
|
||||||
for m in model.modules():
|
if not str(model).endswith(".yaml"):
|
||||||
if not pretrained and hasattr(m, 'reset_parameters'):
|
pretrained = True
|
||||||
m.reset_parameters()
|
weights = get_model(model) # rename this to something less confusing?
|
||||||
for p in model.parameters():
|
model = self.load_model(model_cfg=model if not pretrained else None,
|
||||||
p.requires_grad = True
|
weights=weights if pretrained else None,
|
||||||
|
data=self.data)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def load_cfg(self, cfg):
|
def load_model(self, model_cfg, weights, data):
|
||||||
raise NotImplementedError("This task trainer doesn't support loading cfg files")
|
raise NotImplementedError("This task trainer doesn't support loading cfg files")
|
||||||
|
|
||||||
def get_validator(self):
|
def get_validator(self):
|
||||||
|
@ -3,8 +3,7 @@
|
|||||||
|
|
||||||
|
|
||||||
# Train settings -------------------------------------------------------------------------------------------------------
|
# Train settings -------------------------------------------------------------------------------------------------------
|
||||||
model: null # i.e. yolov5s.pt
|
model: null # i.e. yolov5s.pt, yolo.yaml
|
||||||
cfg: null # i.e. yolov5s.yaml
|
|
||||||
data: null # i.e. coco128.yaml
|
data: null # i.e. coco128.yaml
|
||||||
epochs: 300
|
epochs: 300
|
||||||
batch_size: 16
|
batch_size: 16
|
||||||
@ -70,6 +69,7 @@ mosaic: 1.0 # image mosaic (probability)
|
|||||||
mixup: 0.0 # image mixup (probability)
|
mixup: 0.0 # image mixup (probability)
|
||||||
copy_paste: 0.0 # segment copy-paste (probability)
|
copy_paste: 0.0 # segment copy-paste (probability)
|
||||||
label_smoothing: 0.0
|
label_smoothing: 0.0
|
||||||
|
# anchors: 3
|
||||||
|
|
||||||
# Hydra configs --------------------------------------------------------------------------------------------------------
|
# Hydra configs --------------------------------------------------------------------------------------------------------
|
||||||
hydra:
|
hydra:
|
||||||
|
@ -140,8 +140,3 @@ def download(url, dir=Path.cwd(), unzip=True, delete=True, curl=False, threads=1
|
|||||||
else:
|
else:
|
||||||
for u in [url] if isinstance(url, (str, Path)) else url:
|
for u in [url] if isinstance(url, (str, Path)) else url:
|
||||||
download_one(u, dir)
|
download_one(u, dir)
|
||||||
|
|
||||||
|
|
||||||
def get_model(model: str):
|
|
||||||
# check for local weights
|
|
||||||
pass
|
|
||||||
|
@ -66,7 +66,7 @@ class BaseModel(nn.Module):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def load(self, weights):
|
def load(self, weights):
|
||||||
# Force all tasks implement this function
|
# Force all tasks to implement this function
|
||||||
raise NotImplementedError("This function needs to be implemented by derived classes!")
|
raise NotImplementedError("This function needs to be implemented by derived classes!")
|
||||||
|
|
||||||
|
|
||||||
@ -169,10 +169,10 @@ class DetectionModel(BaseModel):
|
|||||||
mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
|
mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
|
||||||
|
|
||||||
def load(self, weights):
|
def load(self, weights):
|
||||||
ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak
|
csd = weights['model'].float().state_dict() # checkpoint state_dict as FP32
|
||||||
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
|
|
||||||
csd = intersect_state_dicts(csd, self.state_dict()) # intersect
|
csd = intersect_state_dicts(csd, self.state_dict()) # intersect
|
||||||
self.load_state_dict(csd, strict=False) # load
|
self.load_state_dict(csd, strict=False) # load
|
||||||
|
LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from {weights}')
|
||||||
|
|
||||||
|
|
||||||
class SegmentationModel(DetectionModel):
|
class SegmentationModel(DetectionModel):
|
||||||
@ -203,11 +203,33 @@ class ClassificationModel(BaseModel):
|
|||||||
self.nc = nc
|
self.nc = nc
|
||||||
|
|
||||||
def _from_yaml(self, cfg):
|
def _from_yaml(self, cfg):
|
||||||
# Create a YOLOv5 classification model from a *.yaml file
|
# TODO: Create a YOLOv5 classification model from a *.yaml file
|
||||||
self.model = None
|
self.model = None
|
||||||
|
|
||||||
def load(self, weights):
|
def load(self, weights):
|
||||||
ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak
|
model = weights["model"] if isinstance(weights, dict) else weights # torchvision models are not dicts
|
||||||
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
|
csd = model.float().state_dict()
|
||||||
csd = intersect_state_dicts(csd, self.state_dict()) # intersect
|
csd = intersect_state_dicts(csd, self.state_dict()) # intersect
|
||||||
self.load_state_dict(csd, strict=False) # load
|
self.load_state_dict(csd, strict=False) # load
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def reshape_outputs(model, nc):
|
||||||
|
# Update a TorchVision classification model to class count 'n' if required
|
||||||
|
from ultralytics.yolo.utils.modeling.modules import Classify
|
||||||
|
name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1] # last module
|
||||||
|
if isinstance(m, Classify): # YOLO Classify() head
|
||||||
|
if m.linear.out_features != nc:
|
||||||
|
m.linear = nn.Linear(m.linear.in_features, nc)
|
||||||
|
elif isinstance(m, nn.Linear): # ResNet, EfficientNet
|
||||||
|
if m.out_features != nc:
|
||||||
|
setattr(model, name, nn.Linear(m.in_features, nc))
|
||||||
|
elif isinstance(m, nn.Sequential):
|
||||||
|
types = [type(x) for x in m]
|
||||||
|
if nn.Linear in types:
|
||||||
|
i = types.index(nn.Linear) # nn.Linear index
|
||||||
|
if m[i].out_features != nc:
|
||||||
|
m[i] = nn.Linear(m[i].in_features, nc)
|
||||||
|
elif nn.Conv2d in types:
|
||||||
|
i = types.index(nn.Conv2d) # nn.Conv2d index
|
||||||
|
if m[i].out_channels != nc:
|
||||||
|
m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias)
|
||||||
|
@ -1,26 +1,27 @@
|
|||||||
import subprocess
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ultralytics.yolo import v8
|
from ultralytics.yolo import v8
|
||||||
from ultralytics.yolo.data import build_classification_dataloader
|
from ultralytics.yolo.data import build_classification_dataloader
|
||||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
|
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
|
||||||
from ultralytics.yolo.utils import colorstr
|
from ultralytics.yolo.utils.modeling.tasks import ClassificationModel
|
||||||
from ultralytics.yolo.utils.downloads import download
|
|
||||||
from ultralytics.yolo.utils.files import WorkingDirectory
|
|
||||||
from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, torch_distributed_zero_first
|
|
||||||
|
|
||||||
|
|
||||||
# BaseTrainer python usage
|
|
||||||
class ClassificationTrainer(BaseTrainer):
|
class ClassificationTrainer(BaseTrainer):
|
||||||
|
|
||||||
|
def load_model(self, model_cfg, weights, data):
|
||||||
|
# TODO: why treat clf models as unique. We should have clf yamls?
|
||||||
|
if weights and not weights.__class__.__name__.startswith("yolo"): # torchvision
|
||||||
|
model = weights
|
||||||
|
else:
|
||||||
|
model = ClassificationModel(model_cfg, weights, data["nc"])
|
||||||
|
ClassificationModel.reshape_outputs(model, data["nc"])
|
||||||
|
return model
|
||||||
|
|
||||||
def get_dataloader(self, dataset_path, batch_size=None, rank=0):
|
def get_dataloader(self, dataset_path, batch_size=None, rank=0):
|
||||||
return build_classification_dataloader(path=dataset_path,
|
return build_classification_dataloader(path=dataset_path,
|
||||||
imgsz=self.args.img_size,
|
imgsz=self.args.img_size,
|
||||||
batch_size=self.args.batch_size,
|
batch_size=batch_size,
|
||||||
rank=rank)
|
rank=rank)
|
||||||
|
|
||||||
def preprocess_batch(self, batch):
|
def preprocess_batch(self, batch):
|
||||||
|
@ -10,12 +10,11 @@ import torch.nn.functional as F
|
|||||||
from ultralytics.yolo import v8
|
from ultralytics.yolo import v8
|
||||||
from ultralytics.yolo.data import build_dataloader
|
from ultralytics.yolo.data import build_dataloader
|
||||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
|
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
|
||||||
from ultralytics.yolo.utils.downloads import download
|
from ultralytics.yolo.utils.anchors import check_anchors
|
||||||
from ultralytics.yolo.utils.files import WorkingDirectory
|
|
||||||
from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE
|
from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE
|
||||||
from ultralytics.yolo.utils.modeling.tasks import SegmentationModel
|
from ultralytics.yolo.utils.modeling.tasks import SegmentationModel
|
||||||
from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy
|
from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy
|
||||||
from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, de_parallel, torch_distributed_zero_first
|
from ultralytics.yolo.utils.torch_utils import de_parallel
|
||||||
|
|
||||||
|
|
||||||
# BaseTrainer python usage
|
# BaseTrainer python usage
|
||||||
@ -45,8 +44,15 @@ class SegmentationTrainer(BaseTrainer):
|
|||||||
batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
|
batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def load_cfg(self, cfg):
|
def load_model(self, model_cfg, weights, data):
|
||||||
return SegmentationModel(cfg, nc=80)
|
model = SegmentationModel(model_cfg if model_cfg else weights["model"].yaml,
|
||||||
|
ch=3,
|
||||||
|
nc=data["nc"],
|
||||||
|
anchors=self.args.get("anchors"))
|
||||||
|
check_anchors(model, self.args.anchor_t, self.args.img_size)
|
||||||
|
if weights:
|
||||||
|
model.load(weights)
|
||||||
|
return model
|
||||||
|
|
||||||
def get_validator(self):
|
def get_validator(self):
|
||||||
return v8.segment.SegmentationValidator(self.test_loader, self.device, logger=self.console)
|
return v8.segment.SegmentationValidator(self.test_loader, self.device, logger=self.console)
|
||||||
@ -232,7 +238,7 @@ class SegmentationTrainer(BaseTrainer):
|
|||||||
|
|
||||||
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
|
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
|
||||||
def train(cfg):
|
def train(cfg):
|
||||||
cfg.cfg = v8.ROOT / "models/yolov5n-seg.yaml"
|
cfg.model = v8.ROOT / "models/yolov5n-seg.yaml"
|
||||||
cfg.data = cfg.data or "coco128-seg.yaml" # or yolo.ClassificationDataset("mnist")
|
cfg.data = cfg.data or "coco128-seg.yaml" # or yolo.ClassificationDataset("mnist")
|
||||||
trainer = SegmentationTrainer(cfg)
|
trainer = SegmentationTrainer(cfg)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user