Add new get_save_dir() function (#4602)

parent 1121ef2409
commit 23b4f697c9
@@ -8,9 +8,9 @@ from pathlib import Path
 from types import SimpleNamespace
 from typing import Dict, List, Union
 
-from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, SETTINGS, SETTINGS_YAML,
-                               IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn, yaml_load,
-                               yaml_print)
+from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS,
+                               SETTINGS_YAML, IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn,
+                               yaml_load, yaml_print)
 
 # Define valid tasks and modes
 MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
@@ -146,8 +146,23 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
     return IterableSimpleNamespace(**cfg)
 
 
+def get_save_dir(args):
+    """Return save_dir as created from train/val/predict arguments."""
+
+    if getattr(args, 'save_dir', None):
+        save_dir = args.save_dir
+    else:
+        from ultralytics.utils.files import increment_path
+
+        project = args.project or Path(SETTINGS['runs_dir']) / args.task
+        name = args.name or f'{args.mode}'
+        save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in (-1, 0) else True)
+
+    return Path(save_dir)
+
+
 def _handle_deprecation(custom):
-    """Hardcoded function to handle deprecated config keys"""
+    """Hardcoded function to handle deprecated config keys."""
 
     for key in custom.copy().keys():
         if key == 'hide_labels':
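The new helper consolidates the save-directory logic that the trainer, validator, and predictor previously each re-implemented. A minimal usage sketch follows, assuming an ultralytics install that already includes this commit; all argument values are illustrative only:

from types import SimpleNamespace
from ultralytics.cfg import get_save_dir

# No explicit save_dir: falls back to SETTINGS['runs_dir'] / task / mode,
# incremented by increment_path() when the directory already exists.
args = SimpleNamespace(save_dir=None, project=None, name=None,
                       task='detect', mode='predict', exist_ok=False)
print(get_save_dir(args))   # e.g. runs/detect/predict, runs/detect/predict2, ...

# An explicit save_dir wins and is returned unchanged (as a Path).
args.save_dir = 'custom/output'
print(get_save_dir(args))   # custom/output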
@@ -171,6 +186,7 @@ def check_dict_alignment(base: Dict, custom: Dict, e=None):
     Args:
         custom (dict): a dictionary of custom configuration options
         base (dict): a dictionary of base configuration options
+        e (Error, optional): An optional error that is passed by the calling function.
     """
     custom = _handle_deprecation(custom)
     base_keys, custom_keys = (set(x.keys()) for x in (base, custom))
@@ -5,7 +5,7 @@ import sys
 from pathlib import Path
 from typing import Union
 
-from ultralytics.cfg import get_cfg
+from ultralytics.cfg import get_cfg, get_save_dir
 from ultralytics.engine.exporter import Exporter
 from ultralytics.hub.utils import HUB_WEB_ROOT
 from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
@@ -239,7 +239,7 @@ class Model:
         else:  # only update args if predictor is already setup
             self.predictor.args = get_cfg(self.predictor.args, overrides)
             if 'project' in overrides or 'name' in overrides:
-                self.predictor.save_dir = self.predictor.get_save_dir()
+                self.predictor.save_dir = get_save_dir(self.predictor.args)
         # Set prompts for SAM/FastSAM
         if len and hasattr(self.predictor, 'set_prompts'):
             self.predictor.set_prompts(prompts)
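With Model.predict() delegating to the shared helper, passing project or name on a repeat predict() call rebuilds save_dir from the merged predictor args. A hedged sketch; weights, image, and output paths are purely illustrative:

from ultralytics import YOLO

model = YOLO('yolov8n.pt')
model.predict('bus.jpg', save=True)                                # -> runs/detect/predict*
model.predict('bus.jpg', save=True, project='demo', name='run1')   # -> demo/run1*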
@@ -34,11 +34,11 @@ import cv2
 import numpy as np
 import torch
 
-from ultralytics.cfg import get_cfg
+from ultralytics.cfg import get_cfg, get_save_dir
 from ultralytics.data import load_inference_source
 from ultralytics.data.augment import LetterBox, classify_transforms
 from ultralytics.nn.autobackend import AutoBackend
-from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, SETTINGS, WINDOWS, callbacks, colorstr, ops
+from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops
 from ultralytics.utils.checks import check_imgsz, check_imshow
 from ultralytics.utils.files import increment_path
 from ultralytics.utils.torch_utils import select_device, smart_inference_mode
@@ -84,7 +84,7 @@ class BasePredictor:
             overrides (dict, optional): Configuration overrides. Defaults to None.
         """
         self.args = get_cfg(cfg, overrides)
-        self.save_dir = self.get_save_dir()
+        self.save_dir = get_save_dir(self.args)
         if self.args.conf is None:
             self.args.conf = 0.25  # default conf=0.25
         self.done_warmup = False
@@ -108,11 +108,6 @@ class BasePredictor:
         self.txt_path = None
         callbacks.add_integration_callbacks(self)
 
-    def get_save_dir(self):
-        project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
-        name = self.args.name or f'{self.args.mode}'
-        return increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
-
     def preprocess(self, im):
         """Prepares input image before inference.
 
@@ -323,14 +323,10 @@ class Results(SimpleClass):
         if self.probs is not None:
             LOGGER.warning('WARNING ⚠️ Classify task do not support `save_crop`.')
             return
-        if isinstance(save_dir, str):
-            save_dir = Path(save_dir)
-        if isinstance(file_name, str):
-            file_name = Path(file_name)
         for d in self.boxes:
             save_one_box(d.xyxy,
                          self.orig_img.copy(),
-                         file=save_dir / self.names[int(d.cls)] / f'{file_name.stem}.jpg',
+                         file=Path(save_dir) / self.names[int(d.cls)] / f'{Path(file_name).stem}.jpg',
                          BGR=True)
 
     def tojson(self, normalize=False):
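The inline Path() wrapping above means save_crop() accepts either str or Path for save_dir and file_name without the earlier isinstance conversions. A small sketch, assuming a model and image are available locally; names are illustrative:

from pathlib import Path
from ultralytics import YOLO

results = YOLO('yolov8n.pt')('bus.jpg')
results[0].save_crop(save_dir='crops', file_name='im')                # str arguments
results[0].save_crop(save_dir=Path('crops'), file_name=Path('im'))    # Path arguments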
@@ -23,15 +23,15 @@ from torch.cuda import amp
 from torch.nn.parallel import DistributedDataParallel as DDP
 from tqdm import tqdm
 
-from ultralytics.cfg import get_cfg
+from ultralytics.cfg import get_cfg, get_save_dir
 from ultralytics.data.utils import check_cls_dataset, check_det_dataset
 from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
-from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, __version__, callbacks, clean_url,
-                               colorstr, emojis, yaml_save)
+from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, TQDM_BAR_FORMAT, __version__, callbacks, clean_url, colorstr,
+                               emojis, yaml_save)
 from ultralytics.utils.autobatch import check_train_batch_size
 from ultralytics.utils.checks import check_amp, check_file, check_imgsz, print_args
 from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
-from ultralytics.utils.files import get_latest_run, increment_path
+from ultralytics.utils.files import get_latest_run
 from ultralytics.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, select_device,
                                            strip_optimizer)
 
@@ -91,13 +91,7 @@ class BaseTrainer:
         init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
 
         # Dirs
-        project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
-        name = self.args.name or f'{self.args.mode}'
-        if hasattr(self.args, 'save_dir'):
-            self.save_dir = Path(self.args.save_dir)
-        else:
-            self.save_dir = Path(
-                increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in (-1, 0) else True))
+        self.save_dir = get_save_dir(self.args)
         self.wdir = self.save_dir / 'weights'  # weights dir
         if RANK in (-1, 0):
             self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
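The trainer now derives its run directory the same way: get_cfg() builds the args namespace and get_save_dir() resolves project/name, with an explicit save_dir taking precedence. A brief sketch with illustrative values, assuming a version that includes this commit:

from ultralytics.cfg import get_cfg, get_save_dir

args = get_cfg(overrides={'mode': 'train', 'task': 'detect', 'project': 'demo', 'name': 'exp'})
print(get_save_dir(args))   # -> demo/exp (incremented to demo/exp2, ... if it already exists)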
@@ -26,12 +26,11 @@ import numpy as np
 import torch
 from tqdm import tqdm
 
-from ultralytics.cfg import get_cfg
+from ultralytics.cfg import get_cfg, get_save_dir
 from ultralytics.data.utils import check_cls_dataset, check_det_dataset
 from ultralytics.nn.autobackend import AutoBackend
-from ultralytics.utils import LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
+from ultralytics.utils import LOGGER, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
 from ultralytics.utils.checks import check_imgsz
-from ultralytics.utils.files import increment_path
 from ultralytics.utils.ops import Profile
 from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode
 
@@ -71,7 +70,7 @@ class BaseValidator:
 
     Args:
         dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
-        save_dir (Path): Directory to save results.
+        save_dir (Path, optional): Directory to save results.
         pbar (tqdm.tqdm): Progress bar for displaying progress.
         args (SimpleNamespace): Configuration for the validator.
         _callbacks (dict): Dictionary to store various callback functions.
@@ -93,12 +92,8 @@ class BaseValidator:
         self.jdict = None
         self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
 
-        project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
-        name = self.args.name or f'{self.args.mode}'
-        self.save_dir = save_dir or increment_path(Path(project) / name,
-                                                   exist_ok=self.args.exist_ok if RANK in (-1, 0) else True)
+        self.save_dir = save_dir or get_save_dir(self.args)
         (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
 
         if self.args.conf is None:
             self.args.conf = 0.001  # default conf=0.001