mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Add new get_save_dir()
function (#4602)
This commit is contained in:
parent
1121ef2409
commit
23b4f697c9
@ -8,9 +8,9 @@ from pathlib import Path
|
|||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Dict, List, Union
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, SETTINGS, SETTINGS_YAML,
|
from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS,
|
||||||
IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn, yaml_load,
|
SETTINGS_YAML, IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn,
|
||||||
yaml_print)
|
yaml_load, yaml_print)
|
||||||
|
|
||||||
# Define valid tasks and modes
|
# Define valid tasks and modes
|
||||||
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
|
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
|
||||||
@ -146,8 +146,23 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
|
|||||||
return IterableSimpleNamespace(**cfg)
|
return IterableSimpleNamespace(**cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def get_save_dir(args):
|
||||||
|
"""Return save_dir as created from train/val/predict arguments."""
|
||||||
|
|
||||||
|
if getattr(args, 'save_dir', None):
|
||||||
|
save_dir = args.save_dir
|
||||||
|
else:
|
||||||
|
from ultralytics.utils.files import increment_path
|
||||||
|
|
||||||
|
project = args.project or Path(SETTINGS['runs_dir']) / args.task
|
||||||
|
name = args.name or f'{args.mode}'
|
||||||
|
save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in (-1, 0) else True)
|
||||||
|
|
||||||
|
return Path(save_dir)
|
||||||
|
|
||||||
|
|
||||||
def _handle_deprecation(custom):
|
def _handle_deprecation(custom):
|
||||||
"""Hardcoded function to handle deprecated config keys"""
|
"""Hardcoded function to handle deprecated config keys."""
|
||||||
|
|
||||||
for key in custom.copy().keys():
|
for key in custom.copy().keys():
|
||||||
if key == 'hide_labels':
|
if key == 'hide_labels':
|
||||||
@ -171,6 +186,7 @@ def check_dict_alignment(base: Dict, custom: Dict, e=None):
|
|||||||
Args:
|
Args:
|
||||||
custom (dict): a dictionary of custom configuration options
|
custom (dict): a dictionary of custom configuration options
|
||||||
base (dict): a dictionary of base configuration options
|
base (dict): a dictionary of base configuration options
|
||||||
|
e (Error, optional): An optional error that is passed by the calling function.
|
||||||
"""
|
"""
|
||||||
custom = _handle_deprecation(custom)
|
custom = _handle_deprecation(custom)
|
||||||
base_keys, custom_keys = (set(x.keys()) for x in (base, custom))
|
base_keys, custom_keys = (set(x.keys()) for x in (base, custom))
|
||||||
|
@ -5,7 +5,7 @@ import sys
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from ultralytics.cfg import get_cfg
|
from ultralytics.cfg import get_cfg, get_save_dir
|
||||||
from ultralytics.engine.exporter import Exporter
|
from ultralytics.engine.exporter import Exporter
|
||||||
from ultralytics.hub.utils import HUB_WEB_ROOT
|
from ultralytics.hub.utils import HUB_WEB_ROOT
|
||||||
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
|
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
|
||||||
@ -239,7 +239,7 @@ class Model:
|
|||||||
else: # only update args if predictor is already setup
|
else: # only update args if predictor is already setup
|
||||||
self.predictor.args = get_cfg(self.predictor.args, overrides)
|
self.predictor.args = get_cfg(self.predictor.args, overrides)
|
||||||
if 'project' in overrides or 'name' in overrides:
|
if 'project' in overrides or 'name' in overrides:
|
||||||
self.predictor.save_dir = self.predictor.get_save_dir()
|
self.predictor.save_dir = get_save_dir(self.predictor.args)
|
||||||
# Set prompts for SAM/FastSAM
|
# Set prompts for SAM/FastSAM
|
||||||
if len and hasattr(self.predictor, 'set_prompts'):
|
if len and hasattr(self.predictor, 'set_prompts'):
|
||||||
self.predictor.set_prompts(prompts)
|
self.predictor.set_prompts(prompts)
|
||||||
|
@ -34,11 +34,11 @@ import cv2
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ultralytics.cfg import get_cfg
|
from ultralytics.cfg import get_cfg, get_save_dir
|
||||||
from ultralytics.data import load_inference_source
|
from ultralytics.data import load_inference_source
|
||||||
from ultralytics.data.augment import LetterBox, classify_transforms
|
from ultralytics.data.augment import LetterBox, classify_transforms
|
||||||
from ultralytics.nn.autobackend import AutoBackend
|
from ultralytics.nn.autobackend import AutoBackend
|
||||||
from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, SETTINGS, WINDOWS, callbacks, colorstr, ops
|
from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops
|
||||||
from ultralytics.utils.checks import check_imgsz, check_imshow
|
from ultralytics.utils.checks import check_imgsz, check_imshow
|
||||||
from ultralytics.utils.files import increment_path
|
from ultralytics.utils.files import increment_path
|
||||||
from ultralytics.utils.torch_utils import select_device, smart_inference_mode
|
from ultralytics.utils.torch_utils import select_device, smart_inference_mode
|
||||||
@ -84,7 +84,7 @@ class BasePredictor:
|
|||||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||||
"""
|
"""
|
||||||
self.args = get_cfg(cfg, overrides)
|
self.args = get_cfg(cfg, overrides)
|
||||||
self.save_dir = self.get_save_dir()
|
self.save_dir = get_save_dir(self.args)
|
||||||
if self.args.conf is None:
|
if self.args.conf is None:
|
||||||
self.args.conf = 0.25 # default conf=0.25
|
self.args.conf = 0.25 # default conf=0.25
|
||||||
self.done_warmup = False
|
self.done_warmup = False
|
||||||
@ -108,11 +108,6 @@ class BasePredictor:
|
|||||||
self.txt_path = None
|
self.txt_path = None
|
||||||
callbacks.add_integration_callbacks(self)
|
callbacks.add_integration_callbacks(self)
|
||||||
|
|
||||||
def get_save_dir(self):
|
|
||||||
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
|
||||||
name = self.args.name or f'{self.args.mode}'
|
|
||||||
return increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
|
|
||||||
|
|
||||||
def preprocess(self, im):
|
def preprocess(self, im):
|
||||||
"""Prepares input image before inference.
|
"""Prepares input image before inference.
|
||||||
|
|
||||||
|
@ -323,14 +323,10 @@ class Results(SimpleClass):
|
|||||||
if self.probs is not None:
|
if self.probs is not None:
|
||||||
LOGGER.warning('WARNING ⚠️ Classify task do not support `save_crop`.')
|
LOGGER.warning('WARNING ⚠️ Classify task do not support `save_crop`.')
|
||||||
return
|
return
|
||||||
if isinstance(save_dir, str):
|
|
||||||
save_dir = Path(save_dir)
|
|
||||||
if isinstance(file_name, str):
|
|
||||||
file_name = Path(file_name)
|
|
||||||
for d in self.boxes:
|
for d in self.boxes:
|
||||||
save_one_box(d.xyxy,
|
save_one_box(d.xyxy,
|
||||||
self.orig_img.copy(),
|
self.orig_img.copy(),
|
||||||
file=save_dir / self.names[int(d.cls)] / f'{file_name.stem}.jpg',
|
file=Path(save_dir) / self.names[int(d.cls)] / f'{Path(file_name).stem}.jpg',
|
||||||
BGR=True)
|
BGR=True)
|
||||||
|
|
||||||
def tojson(self, normalize=False):
|
def tojson(self, normalize=False):
|
||||||
|
@ -23,15 +23,15 @@ from torch.cuda import amp
|
|||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from ultralytics.cfg import get_cfg
|
from ultralytics.cfg import get_cfg, get_save_dir
|
||||||
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
||||||
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
|
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
|
||||||
from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, __version__, callbacks, clean_url,
|
from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, TQDM_BAR_FORMAT, __version__, callbacks, clean_url, colorstr,
|
||||||
colorstr, emojis, yaml_save)
|
emojis, yaml_save)
|
||||||
from ultralytics.utils.autobatch import check_train_batch_size
|
from ultralytics.utils.autobatch import check_train_batch_size
|
||||||
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, print_args
|
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, print_args
|
||||||
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
|
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
|
||||||
from ultralytics.utils.files import get_latest_run, increment_path
|
from ultralytics.utils.files import get_latest_run
|
||||||
from ultralytics.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, select_device,
|
from ultralytics.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, select_device,
|
||||||
strip_optimizer)
|
strip_optimizer)
|
||||||
|
|
||||||
@ -91,13 +91,7 @@ class BaseTrainer:
|
|||||||
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
|
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
|
||||||
|
|
||||||
# Dirs
|
# Dirs
|
||||||
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
self.save_dir = get_save_dir(self.args)
|
||||||
name = self.args.name or f'{self.args.mode}'
|
|
||||||
if hasattr(self.args, 'save_dir'):
|
|
||||||
self.save_dir = Path(self.args.save_dir)
|
|
||||||
else:
|
|
||||||
self.save_dir = Path(
|
|
||||||
increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in (-1, 0) else True))
|
|
||||||
self.wdir = self.save_dir / 'weights' # weights dir
|
self.wdir = self.save_dir / 'weights' # weights dir
|
||||||
if RANK in (-1, 0):
|
if RANK in (-1, 0):
|
||||||
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
|
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
|
||||||
|
@ -26,12 +26,11 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from ultralytics.cfg import get_cfg
|
from ultralytics.cfg import get_cfg, get_save_dir
|
||||||
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
||||||
from ultralytics.nn.autobackend import AutoBackend
|
from ultralytics.nn.autobackend import AutoBackend
|
||||||
from ultralytics.utils import LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
|
from ultralytics.utils import LOGGER, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
|
||||||
from ultralytics.utils.checks import check_imgsz
|
from ultralytics.utils.checks import check_imgsz
|
||||||
from ultralytics.utils.files import increment_path
|
|
||||||
from ultralytics.utils.ops import Profile
|
from ultralytics.utils.ops import Profile
|
||||||
from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode
|
from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode
|
||||||
|
|
||||||
@ -71,7 +70,7 @@ class BaseValidator:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
|
dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
|
||||||
save_dir (Path): Directory to save results.
|
save_dir (Path, optional): Directory to save results.
|
||||||
pbar (tqdm.tqdm): Progress bar for displaying progress.
|
pbar (tqdm.tqdm): Progress bar for displaying progress.
|
||||||
args (SimpleNamespace): Configuration for the validator.
|
args (SimpleNamespace): Configuration for the validator.
|
||||||
_callbacks (dict): Dictionary to store various callback functions.
|
_callbacks (dict): Dictionary to store various callback functions.
|
||||||
@ -93,12 +92,8 @@ class BaseValidator:
|
|||||||
self.jdict = None
|
self.jdict = None
|
||||||
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
|
||||||
|
|
||||||
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
self.save_dir = save_dir or get_save_dir(self.args)
|
||||||
name = self.args.name or f'{self.args.mode}'
|
|
||||||
self.save_dir = save_dir or increment_path(Path(project) / name,
|
|
||||||
exist_ok=self.args.exist_ok if RANK in (-1, 0) else True)
|
|
||||||
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
if self.args.conf is None:
|
if self.args.conf is None:
|
||||||
self.args.conf = 0.001 # default conf=0.001
|
self.args.conf = 0.001 # default conf=0.001
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user