mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 13:34:23 +08:00
Update Validator to use model
argument (#4480)
This commit is contained in:
parent
615ddc9d97
commit
b2f279ffdd
@ -82,7 +82,7 @@ def cfg2dict(cfg):
|
||||
Convert a configuration object to a dictionary, whether it is a file path, a string, or a SimpleNamespace object.
|
||||
|
||||
Args:
|
||||
cfg (str | Path | SimpleNamespace): Configuration object to be converted to a dictionary.
|
||||
cfg (str | Path | dict | SimpleNamespace): Configuration object to be converted to a dictionary.
|
||||
|
||||
Returns:
|
||||
cfg (dict): Configuration object in dictionary format.
|
||||
@ -110,6 +110,7 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
|
||||
# Merge overrides
|
||||
if overrides:
|
||||
overrides = cfg2dict(overrides)
|
||||
overrides.pop('save_dir', None) # special override keys to ignore
|
||||
check_dict_alignment(cfg, overrides)
|
||||
cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides)
|
||||
|
||||
|
@ -29,7 +29,7 @@ from tqdm import tqdm
|
||||
from ultralytics.cfg import get_cfg
|
||||
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
||||
from ultralytics.nn.autobackend import AutoBackend
|
||||
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
|
||||
from ultralytics.utils import LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
|
||||
from ultralytics.utils.checks import check_imgsz
|
||||
from ultralytics.utils.files import increment_path
|
||||
from ultralytics.utils.ops import Profile
|
||||
@ -43,9 +43,9 @@ class BaseValidator:
|
||||
A base class for creating validators.
|
||||
|
||||
Attributes:
|
||||
args (SimpleNamespace): Configuration for the validator.
|
||||
dataloader (DataLoader): Dataloader to use for validation.
|
||||
pbar (tqdm): Progress bar to update during validation.
|
||||
args (SimpleNamespace): Configuration for the validator.
|
||||
model (nn.Module): Model to validate.
|
||||
data (dict): Data dictionary.
|
||||
device (torch.device): Device to use for validation.
|
||||
@ -76,9 +76,9 @@ class BaseValidator:
|
||||
args (SimpleNamespace): Configuration for the validator.
|
||||
_callbacks (dict): Dictionary to store various callback functions.
|
||||
"""
|
||||
self.args = get_cfg(overrides=args)
|
||||
self.dataloader = dataloader
|
||||
self.pbar = pbar
|
||||
self.args = args or get_cfg(DEFAULT_CFG)
|
||||
self.model = None
|
||||
self.data = None
|
||||
self.device = None
|
||||
@ -126,8 +126,7 @@ class BaseValidator:
|
||||
else:
|
||||
callbacks.add_integration_callbacks(self)
|
||||
self.run_callbacks('on_val_start')
|
||||
assert model is not None, 'Either trainer or model is needed for validation'
|
||||
model = AutoBackend(model,
|
||||
model = AutoBackend(model or self.args.model,
|
||||
device=select_device(self.args.device, self.args.batch),
|
||||
dnn=self.args.dnn,
|
||||
data=self.args.data,
|
||||
|
@ -14,7 +14,7 @@ from ultralytics.utils import colorstr, ops
|
||||
__all__ = 'RTDETRValidator', # tuple or list
|
||||
|
||||
|
||||
# TODO: Temporarily, RT-DETR does not need padding.
|
||||
# TODO: Temporarily RT-DETR does not need padding.
|
||||
class RTDETRDataset(YOLODataset):
|
||||
|
||||
def __init__(self, *args, data=None, **kwargs):
|
||||
@ -47,7 +47,7 @@ class RTDETRDataset(YOLODataset):
|
||||
return self.ims[i], self.im_hw0[i], self.im_hw[i]
|
||||
|
||||
def build_transforms(self, hyp=None):
|
||||
"""Temporarily, only for evaluation."""
|
||||
"""Temporary, only for evaluation."""
|
||||
if self.augment:
|
||||
hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
|
||||
hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
|
||||
@ -76,12 +76,13 @@ class RTDETRValidator(DetectionValidator):
|
||||
|
||||
args = dict(model='rtdetr-l.pt', data='coco8.yaml')
|
||||
validator = RTDETRValidator(args=args)
|
||||
validator(model=args['model'])
|
||||
validator()
|
||||
```
|
||||
"""
|
||||
|
||||
def build_dataset(self, img_path, mode='val', batch=None):
|
||||
"""Build YOLO Dataset
|
||||
"""
|
||||
Build an RTDETR Dataset.
|
||||
|
||||
Args:
|
||||
img_path (str): Path to the folder containing images.
|
||||
|
@ -22,7 +22,7 @@ class ClassificationValidator(BaseValidator):
|
||||
|
||||
args = dict(model='yolov8n-cls.pt', data='imagenet10')
|
||||
validator = ClassificationValidator(args=args)
|
||||
validator(model=args['model'])
|
||||
validator()
|
||||
```
|
||||
"""
|
||||
|
||||
|
@ -25,7 +25,7 @@ class DetectionValidator(BaseValidator):
|
||||
|
||||
args = dict(model='yolov8n.pt', data='coco8.yaml')
|
||||
validator = DetectionValidator(args=args)
|
||||
validator(model=args['model'])
|
||||
validator()
|
||||
```
|
||||
"""
|
||||
|
||||
|
@ -22,7 +22,7 @@ class PoseValidator(DetectionValidator):
|
||||
|
||||
args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml')
|
||||
validator = PoseValidator(args=args)
|
||||
validator(model=args['model'])
|
||||
validator()
|
||||
```
|
||||
"""
|
||||
|
||||
|
@ -24,7 +24,7 @@ class SegmentationValidator(DetectionValidator):
|
||||
|
||||
args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml')
|
||||
validator = SegmentationValidator(args=args)
|
||||
validator(model=args['model'])
|
||||
validator()
|
||||
```
|
||||
"""
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user