mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
standalone val (#56)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
3a241e4cea
commit
5a52e7663a
3
.github/workflows/ci.yaml
vendored
3
.github/workflows/ci.yaml
vendored
@ -93,9 +93,12 @@ jobs:
|
|||||||
echo "TODO"
|
echo "TODO"
|
||||||
- name: Test segmentation
|
- name: Test segmentation
|
||||||
shell: bash # for Windows compatibility
|
shell: bash # for Windows compatibility
|
||||||
|
# TODO: redo val test without hardcoded weights
|
||||||
run: |
|
run: |
|
||||||
yolo task=segment mode=train model=yolov5n-seg.yaml data=coco128-seg.yaml epochs=1 img_size=64
|
yolo task=segment mode=train model=yolov5n-seg.yaml data=coco128-seg.yaml epochs=1 img_size=64
|
||||||
|
yolo task=segment mode=val model=runs/exp/weights/last.pt data=coco128-seg.yaml img_size=64
|
||||||
- name: Test classification
|
- name: Test classification
|
||||||
shell: bash # for Windows compatibility
|
shell: bash # for Windows compatibility
|
||||||
run: |
|
run: |
|
||||||
yolo task=classify mode=train model=resnet18 data=mnist160 epochs=1 img_size=32
|
yolo task=classify mode=train model=resnet18 data=mnist160 epochs=1 img_size=32
|
||||||
|
yolo task=classify mode=val model=runs/exp2/weights/last.pt data=mnist160
|
@ -208,6 +208,9 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
|||||||
sample = self.torch_transforms(im)
|
sample = self.torch_transforms(im)
|
||||||
return OrderedDict(img=sample, cls=j)
|
return OrderedDict(img=sample, cls=j)
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.samples)
|
||||||
|
|
||||||
|
|
||||||
# TODO: support semantic segmentation
|
# TODO: support semantic segmentation
|
||||||
class SemanticDataset(BaseDataset):
|
class SemanticDataset(BaseDataset):
|
||||||
|
19
ultralytics/yolo/engine/exporter.py
Normal file
19
ultralytics/yolo/engine/exporter.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
def export_formats():
|
||||||
|
# YOLOv5 export formats
|
||||||
|
x = [
|
||||||
|
['PyTorch', '-', '.pt', True, True],
|
||||||
|
['TorchScript', 'torchscript', '.torchscript', True, True],
|
||||||
|
['ONNX', 'onnx', '.onnx', True, True],
|
||||||
|
['OpenVINO', 'openvino', '_openvino_model', True, False],
|
||||||
|
['TensorRT', 'engine', '.engine', False, True],
|
||||||
|
['CoreML', 'coreml', '.mlmodel', True, False],
|
||||||
|
['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True],
|
||||||
|
['TensorFlow GraphDef', 'pb', '.pb', True, True],
|
||||||
|
['TensorFlow Lite', 'tflite', '.tflite', True, False],
|
||||||
|
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False],
|
||||||
|
['TensorFlow.js', 'tfjs', '_web_model', False, False],
|
||||||
|
['PaddlePaddle', 'paddle', '_paddle_model', True, True],]
|
||||||
|
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
|
@ -25,7 +25,7 @@ import ultralytics.yolo.utils as utils
|
|||||||
import ultralytics.yolo.utils.callbacks as callbacks
|
import ultralytics.yolo.utils.callbacks as callbacks
|
||||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||||
from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr
|
from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr
|
||||||
from ultralytics.yolo.utils.checks import print_args
|
from ultralytics.yolo.utils.checks import check_file, print_args
|
||||||
from ultralytics.yolo.utils.files import increment_path, save_yaml
|
from ultralytics.yolo.utils.files import increment_path, save_yaml
|
||||||
from ultralytics.yolo.utils.modeling import get_model
|
from ultralytics.yolo.utils.modeling import get_model
|
||||||
from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
|
from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
|
||||||
@ -299,13 +299,16 @@ class BaseTrainer:
|
|||||||
"""
|
"""
|
||||||
Get train, val path from data dict if it exists. Returns None if data format is not recognized
|
Get train, val path from data dict if it exists. Returns None if data format is not recognized
|
||||||
"""
|
"""
|
||||||
return data["train"], data["val"]
|
return data["train"], data.get("val") or data.get("test")
|
||||||
|
|
||||||
def get_model(self, model: Union[str, Path]):
|
def get_model(self, model: Union[str, Path]):
|
||||||
"""
|
"""
|
||||||
load/create/download model for any task
|
load/create/download model for any task
|
||||||
"""
|
"""
|
||||||
pretrained = not str(model).endswith(".yaml")
|
pretrained = True
|
||||||
|
if str(model).endswith(".yaml"):
|
||||||
|
model = check_file(model)
|
||||||
|
pretrained = False
|
||||||
return self.load_model(model_cfg=None if pretrained else model,
|
return self.load_model(model_cfg=None if pretrained else model,
|
||||||
weights=get_model(model) if pretrained else None,
|
weights=get_model(model) if pretrained else None,
|
||||||
data=self.data) # model
|
data=self.data) # model
|
||||||
@ -376,7 +379,7 @@ class BaseTrainer:
|
|||||||
"""
|
"""
|
||||||
To set or update model parameters before training.
|
To set or update model parameters before training.
|
||||||
"""
|
"""
|
||||||
pass
|
self.model.names = self.data["names"]
|
||||||
|
|
||||||
def build_targets(self, preds, targets):
|
def build_targets(self, preds, targets):
|
||||||
pass
|
pass
|
||||||
|
@ -5,11 +5,14 @@ import torch
|
|||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
||||||
from ultralytics.yolo.utils import TQDM_BAR_FORMAT
|
from ultralytics.yolo.utils import LOGGER, TQDM_BAR_FORMAT
|
||||||
from ultralytics.yolo.utils.files import increment_path
|
from ultralytics.yolo.utils.files import increment_path
|
||||||
|
from ultralytics.yolo.utils.modeling import get_model
|
||||||
|
from ultralytics.yolo.utils.modeling.autobackend import AutoBackend
|
||||||
from ultralytics.yolo.utils.ops import Profile
|
from ultralytics.yolo.utils.ops import Profile
|
||||||
from ultralytics.yolo.utils.torch_utils import de_parallel, select_device
|
from ultralytics.yolo.utils.torch_utils import check_img_size, de_parallel, select_device
|
||||||
|
|
||||||
|
|
||||||
class BaseValidator:
|
class BaseValidator:
|
||||||
@ -17,17 +20,18 @@ class BaseValidator:
|
|||||||
Base validator class.
|
Base validator class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dataloader, save_dir=None, pbar=None, logger=None, args=None):
|
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
|
||||||
self.dataloader = dataloader
|
self.dataloader = dataloader
|
||||||
self.pbar = pbar
|
self.pbar = pbar
|
||||||
self.logger = logger or logging.getLogger()
|
self.logger = logger or LOGGER
|
||||||
self.args = args or OmegaConf.load(DEFAULT_CONFIG)
|
self.args = args or OmegaConf.load(DEFAULT_CONFIG)
|
||||||
self.device = select_device(self.args.device, dataloader.batch_size)
|
self.model = None
|
||||||
self.save_dir = save_dir if save_dir is not None else \
|
self.data = None
|
||||||
increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
|
self.device = None
|
||||||
self.cuda = self.device.type != 'cpu'
|
|
||||||
self.batch_i = None
|
self.batch_i = None
|
||||||
self.training = True
|
self.training = True
|
||||||
|
self.save_dir = save_dir if save_dir is not None else \
|
||||||
|
increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
|
||||||
|
|
||||||
def __call__(self, trainer=None, model=None):
|
def __call__(self, trainer=None, model=None):
|
||||||
"""
|
"""
|
||||||
@ -36,14 +40,35 @@ class BaseValidator:
|
|||||||
"""
|
"""
|
||||||
self.training = trainer is not None
|
self.training = trainer is not None
|
||||||
if self.training:
|
if self.training:
|
||||||
|
self.device = trainer.device
|
||||||
|
self.data = trainer.data
|
||||||
model = trainer.ema.ema or trainer.model
|
model = trainer.ema.ema or trainer.model
|
||||||
self.args.half &= self.device.type != 'cpu'
|
self.args.half &= self.device.type != 'cpu'
|
||||||
model = model.half() if self.args.half else model.float()
|
model = model.half() if self.args.half else model.float()
|
||||||
|
self.model = model
|
||||||
loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
|
loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
|
||||||
else: # TODO: handle this when detectMultiBackend is supported
|
else: # TODO: handle this when detectMultiBackend is supported
|
||||||
assert model is not None, "Either trainer or model is needed for validation"
|
assert model is not None, "Either trainer or model is needed for validation"
|
||||||
# model = DetectMultiBacked(model)
|
self.device = select_device(self.args.device, self.args.batch_size)
|
||||||
# TODO: implement init_model_attributes()
|
self.args.half &= self.device.type != 'cpu'
|
||||||
|
model = AutoBackend(model, device=self.device, dnn=self.args.dnn, fp16=self.args.half)
|
||||||
|
self.model = model
|
||||||
|
stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
|
||||||
|
imgsz = check_img_size(self.args.img_size, s=stride)
|
||||||
|
if engine:
|
||||||
|
self.args.batch_size = model.batch_size
|
||||||
|
else:
|
||||||
|
self.device = model.device
|
||||||
|
if not (pt or jit):
|
||||||
|
self.args.batch_size = 1 # export.py models default to batch-size 1
|
||||||
|
self.logger.info(
|
||||||
|
f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
|
||||||
|
|
||||||
|
if self.args.data.endswith(".yaml"):
|
||||||
|
data = check_dataset_yaml(self.args.data)
|
||||||
|
else:
|
||||||
|
data = check_dataset(self.args.data)
|
||||||
|
self.dataloader = self.get_dataloader(data.get("val") or data.set("test"), self.args.batch_size)
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
@ -101,6 +126,9 @@ class BaseValidator:
|
|||||||
return stats | trainer.label_loss_items(loss.cpu() / len(self.dataloader), prefix="val") \
|
return stats | trainer.label_loss_items(loss.cpu() / len(self.dataloader), prefix="val") \
|
||||||
if self.training else stats
|
if self.training else stats
|
||||||
|
|
||||||
|
def get_dataloader(self, dataset_path, batch_size):
|
||||||
|
raise Exception("get_dataloder function not implemented for this validator")
|
||||||
|
|
||||||
def preprocess(self, batch):
|
def preprocess(self, batch):
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
@ -28,17 +28,22 @@ single_cls: False # train multi-class data as single-class
|
|||||||
image_weights: False # use weighted image selection for training
|
image_weights: False # use weighted image selection for training
|
||||||
rect: False # support rectangular training
|
rect: False # support rectangular training
|
||||||
cos_lr: False # Use cosine LR scheduler
|
cos_lr: False # Use cosine LR scheduler
|
||||||
overlap_mask: True # Segmentation masks overlap
|
# Segmentation
|
||||||
mask_ratio: 4 # Segmentation mask downsample ratio
|
overlap_mask: True # masks overlap
|
||||||
noval: False
|
mask_ratio: 4 # mask downsample ratio
|
||||||
|
# Classification
|
||||||
|
dropout: False # use dropout
|
||||||
|
|
||||||
|
|
||||||
# Val/Test settings ----------------------------------------------------------------------------------------------------
|
# Val/Test settings ----------------------------------------------------------------------------------------------------
|
||||||
|
noval: False
|
||||||
save_json: False
|
save_json: False
|
||||||
save_hybrid: False
|
save_hybrid: False
|
||||||
conf_thres: 0.001
|
conf_thres: 0.001
|
||||||
iou_thres: 0.6
|
iou_thres: 0.6
|
||||||
max_det: 300
|
max_det: 300
|
||||||
half: True
|
half: True
|
||||||
|
dnn: False # use OpenCV DNN for ONNX inference
|
||||||
plots: False
|
plots: False
|
||||||
save_txt: False
|
save_txt: False
|
||||||
|
|
||||||
|
@ -113,8 +113,8 @@ def get_model(model='s.pt', pretrained=True):
|
|||||||
model = model.split(".")[0]
|
model = model.split(".")[0]
|
||||||
|
|
||||||
if Path(f"{model}.pt").is_file(): # local file
|
if Path(f"{model}.pt").is_file(): # local file
|
||||||
return torch.load(f"{model}.pt", map_location='cpu')
|
return attempt_load_weights(f"{model}.pt", device='cpu')
|
||||||
elif model in torchvision.models.__dict__: # TorchVision models i.e. resnet50, efficientnet_b0
|
elif model in torchvision.models.__dict__: # TorchVision models i.e. resnet50, efficientnet_b0
|
||||||
return torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
|
return torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
|
||||||
else: # Ultralytics assets
|
else: # Ultralytics assets
|
||||||
return torch.load(attempt_download(f"{model}.pt"), map_location='cpu')
|
return attempt_load_weights(f"{model}.pt", device='cpu')
|
||||||
|
@ -304,7 +304,7 @@ class AutoBackend(nn.Module):
|
|||||||
def _model_type(p='path/to/model.pt'):
|
def _model_type(p='path/to/model.pt'):
|
||||||
# Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
|
# Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
|
||||||
# types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
|
# types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
|
||||||
from export import export_formats
|
from ultralytics.yolo.engine.exporter import export_formats
|
||||||
sf = list(export_formats().Suffix) # export suffixes
|
sf = list(export_formats().Suffix) # export suffixes
|
||||||
if not is_url(p, check=False):
|
if not is_url(p, check=False):
|
||||||
check_suffix(p, sf) # checks
|
check_suffix(p, sf) # checks
|
||||||
|
@ -172,7 +172,7 @@ class DetectionModel(BaseModel):
|
|||||||
csd = weights['model'].float().state_dict() # checkpoint state_dict as FP32
|
csd = weights['model'].float().state_dict() # checkpoint state_dict as FP32
|
||||||
csd = intersect_state_dicts(csd, self.state_dict()) # intersect
|
csd = intersect_state_dicts(csd, self.state_dict()) # intersect
|
||||||
self.load_state_dict(csd, strict=False) # load
|
self.load_state_dict(csd, strict=False) # load
|
||||||
LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from {weights}')
|
LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights')
|
||||||
|
|
||||||
|
|
||||||
class SegmentationModel(DetectionModel):
|
class SegmentationModel(DetectionModel):
|
||||||
|
@ -164,6 +164,25 @@ def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
|
|||||||
return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
|
return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
|
||||||
|
|
||||||
|
|
||||||
|
def check_img_size(imgsz, s=32, floor=0):
|
||||||
|
# Verify image size is a multiple of stride s in each dimension
|
||||||
|
if isinstance(imgsz, int): # integer i.e. img_size=640
|
||||||
|
new_size = max(make_divisible(imgsz, int(s)), floor)
|
||||||
|
else: # list i.e. img_size=[640, 480]
|
||||||
|
imgsz = list(imgsz) # convert to list if tuple
|
||||||
|
new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
|
||||||
|
if new_size != imgsz:
|
||||||
|
LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
|
||||||
|
return new_size
|
||||||
|
|
||||||
|
|
||||||
|
def make_divisible(x, divisor):
|
||||||
|
# Returns nearest x divisible by divisor
|
||||||
|
if isinstance(divisor, torch.Tensor):
|
||||||
|
divisor = int(divisor.max()) # to int
|
||||||
|
return math.ceil(x / divisor) * divisor
|
||||||
|
|
||||||
|
|
||||||
def copy_attr(a, b, include=(), exclude=()):
|
def copy_attr(a, b, include=(), exclude=()):
|
||||||
# Copy attributes from b to a, options to only include [...] and to exclude [...]
|
# Copy attributes from b to a, options to only include [...] and to exclude [...]
|
||||||
for k, v in b.__dict__.items():
|
for k, v in b.__dict__.items():
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from ultralytics.yolo.v8.classify.train import ClassificationTrainer, train
|
from ultralytics.yolo.v8.classify.train import ClassificationTrainer, train
|
||||||
from ultralytics.yolo.v8.classify.val import ClassificationValidator
|
from ultralytics.yolo.v8.classify.val import ClassificationValidator, val
|
||||||
|
|
||||||
__all__ = ["train"]
|
__all__ = ["train"]
|
||||||
|
@ -19,6 +19,13 @@ class ClassificationTrainer(BaseTrainer):
|
|||||||
else:
|
else:
|
||||||
model = ClassificationModel(model_cfg, weights, data["nc"])
|
model = ClassificationModel(model_cfg, weights, data["nc"])
|
||||||
ClassificationModel.reshape_outputs(model, data["nc"])
|
ClassificationModel.reshape_outputs(model, data["nc"])
|
||||||
|
for m in model.modules():
|
||||||
|
if not weights and hasattr(m, 'reset_parameters'):
|
||||||
|
m.reset_parameters()
|
||||||
|
if isinstance(m, torch.nn.Dropout) and self.args.dropout is not None:
|
||||||
|
m.p = self.args.dropout # set dropout
|
||||||
|
for p in model.parameters():
|
||||||
|
p.requires_grad = True # for training
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def get_dataloader(self, dataset_path, batch_size, rank=0, mode="train"):
|
def get_dataloader(self, dataset_path, batch_size, rank=0, mode="train"):
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
|
import hydra
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from ultralytics.yolo.data import build_classification_dataloader
|
||||||
|
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
||||||
from ultralytics.yolo.engine.validator import BaseValidator
|
from ultralytics.yolo.engine.validator import BaseValidator
|
||||||
|
|
||||||
|
|
||||||
@ -24,6 +27,21 @@ class ClassificationValidator(BaseValidator):
|
|||||||
top1, top5 = acc.mean(0).tolist()
|
top1, top5 = acc.mean(0).tolist()
|
||||||
return {"top1": top1, "top5": top5, "fitness": top5}
|
return {"top1": top1, "top5": top5, "fitness": top5}
|
||||||
|
|
||||||
|
def get_dataloader(self, dataset_path, batch_size):
|
||||||
|
return build_classification_dataloader(path=dataset_path, imgsz=self.args.img_size, batch_size=batch_size)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def metric_keys(self):
|
def metric_keys(self):
|
||||||
return ["top1", "top5"]
|
return ["top1", "top5"]
|
||||||
|
|
||||||
|
|
||||||
|
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
|
||||||
|
def val(cfg):
|
||||||
|
cfg.data = cfg.data or "imagenette160"
|
||||||
|
cfg.model = cfg.model or "resnet18"
|
||||||
|
validator = ClassificationValidator(args=cfg)
|
||||||
|
validator(model=cfg.model)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
val()
|
||||||
|
@ -1,2 +1,2 @@
|
|||||||
from ultralytics.yolo.v8.segment.train import SegmentationTrainer, train
|
from ultralytics.yolo.v8.segment.train import SegmentationTrainer, train
|
||||||
from ultralytics.yolo.v8.segment.val import SegmentationValidator
|
from ultralytics.yolo.v8.segment.val import SegmentationValidator, val
|
||||||
|
@ -33,6 +33,8 @@ class SegmentationTrainer(BaseTrainer):
|
|||||||
anchors=self.args.get("anchors"))
|
anchors=self.args.get("anchors"))
|
||||||
if weights:
|
if weights:
|
||||||
model.load(weights)
|
model.load(weights)
|
||||||
|
for _, v in model.named_parameters():
|
||||||
|
v.requires_grad = True # train all layers
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def set_model_attributes(self):
|
def set_model_attributes(self):
|
||||||
@ -257,7 +259,7 @@ class SegmentationTrainer(BaseTrainer):
|
|||||||
|
|
||||||
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
|
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
|
||||||
def train(cfg):
|
def train(cfg):
|
||||||
cfg.model = v8.ROOT / "models/yolov5n-seg.yaml"
|
cfg.model = cfg.model or "models/yolov5n-seg.yaml"
|
||||||
cfg.data = cfg.data or "coco128-seg.yaml" # or yolo.ClassificationDataset("mnist")
|
cfg.data = cfg.data or "coco128-seg.yaml" # or yolo.ClassificationDataset("mnist")
|
||||||
trainer = SegmentationTrainer(cfg)
|
trainer = SegmentationTrainer(cfg)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
import hydra
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from ultralytics.yolo.data import build_dataloader
|
||||||
|
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
||||||
from ultralytics.yolo.engine.validator import BaseValidator
|
from ultralytics.yolo.engine.validator import BaseValidator
|
||||||
from ultralytics.yolo.utils import ops
|
from ultralytics.yolo.utils import ops
|
||||||
from ultralytics.yolo.utils.checks import check_file, check_requirements
|
from ultralytics.yolo.utils.checks import check_file, check_requirements
|
||||||
@ -16,7 +19,7 @@ from ultralytics.yolo.utils.torch_utils import de_parallel
|
|||||||
|
|
||||||
class SegmentationValidator(BaseValidator):
|
class SegmentationValidator(BaseValidator):
|
||||||
|
|
||||||
def __init__(self, dataloader, save_dir=None, pbar=None, logger=None, args=None):
|
def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
|
||||||
super().__init__(dataloader, save_dir, pbar, logger, args)
|
super().__init__(dataloader, save_dir, pbar, logger, args)
|
||||||
if self.args.save_json:
|
if self.args.save_json:
|
||||||
check_requirements(['pycocotools'])
|
check_requirements(['pycocotools'])
|
||||||
@ -43,14 +46,17 @@ class SegmentationValidator(BaseValidator):
|
|||||||
return batch
|
return batch
|
||||||
|
|
||||||
def init_metrics(self, model):
|
def init_metrics(self, model):
|
||||||
|
if self.training:
|
||||||
head = de_parallel(model).model[-1]
|
head = de_parallel(model).model[-1]
|
||||||
if self.data_dict:
|
else:
|
||||||
self.is_coco = isinstance(self.data_dict.get('val'),
|
head = de_parallel(model).model.model[-1]
|
||||||
str) and self.data_dict['val'].endswith(f'coco{os.sep}val2017.txt')
|
|
||||||
self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000))
|
|
||||||
|
|
||||||
|
if self.data:
|
||||||
|
self.is_coco = isinstance(self.data.get('val'),
|
||||||
|
str) and self.data['val'].endswith(f'coco{os.sep}val2017.txt')
|
||||||
|
self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000))
|
||||||
|
self.nm = head.nm if hasattr(head, "nm") else 32
|
||||||
self.nc = head.nc
|
self.nc = head.nc
|
||||||
self.nm = head.nm
|
|
||||||
self.names = model.names
|
self.names = model.names
|
||||||
if isinstance(self.names, (list, tuple)): # old format
|
if isinstance(self.names, (list, tuple)): # old format
|
||||||
self.names = dict(enumerate(self.names))
|
self.names = dict(enumerate(self.names))
|
||||||
@ -206,6 +212,12 @@ class SegmentationValidator(BaseValidator):
|
|||||||
correct[matches[:, 1].astype(int), i] = True
|
correct[matches[:, 1].astype(int), i] = True
|
||||||
return torch.tensor(correct, dtype=torch.bool, device=iouv.device)
|
return torch.tensor(correct, dtype=torch.bool, device=iouv.device)
|
||||||
|
|
||||||
|
def get_dataloader(self, dataset_path, batch_size):
|
||||||
|
# TODO: manage splits differently
|
||||||
|
# calculate stride - check if model is initialized
|
||||||
|
gs = max(int(de_parallel(self.model).stride if self.model else 0), 32)
|
||||||
|
return build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, mode="val")[0]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def metric_keys(self):
|
def metric_keys(self):
|
||||||
return [
|
return [
|
||||||
@ -243,3 +255,14 @@ class SegmentationValidator(BaseValidator):
|
|||||||
plot_images_and_masks(images, batch_idx, cls, bboxes, plot_masks, paths, conf,
|
plot_images_and_masks(images, batch_idx, cls, bboxes, plot_masks, paths, conf,
|
||||||
self.save_dir / f'val_batch{ni}_pred.jpg', self.names) # pred
|
self.save_dir / f'val_batch{ni}_pred.jpg', self.names) # pred
|
||||||
self.plot_masks.clear()
|
self.plot_masks.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
|
||||||
|
def val(cfg):
|
||||||
|
cfg.data = cfg.data or "coco128-seg.yaml"
|
||||||
|
validator = SegmentationValidator(args=cfg)
|
||||||
|
validator(model=cfg.model)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
val()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user