mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-24 05:55:51 +08:00
Improvements (#142)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
dcd8ef68e6
commit
55bdca6768
@ -73,7 +73,7 @@ from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
|
|||||||
from ultralytics.yolo.data.utils import check_dataset
|
from ultralytics.yolo.data.utils import check_dataset
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, callbacks, colorstr, get_default_args, yaml_save
|
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, callbacks, colorstr, get_default_args, yaml_save
|
||||||
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
|
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
|
||||||
from ultralytics.yolo.utils.files import file_size, increment_path
|
from ultralytics.yolo.utils.files import file_size
|
||||||
from ultralytics.yolo.utils.ops import Profile
|
from ultralytics.yolo.utils.ops import Profile
|
||||||
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, select_device, smart_inference_mode
|
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, select_device, smart_inference_mode
|
||||||
|
|
||||||
@ -138,10 +138,6 @@ class Exporter:
|
|||||||
if overrides is None:
|
if overrides is None:
|
||||||
overrides = {}
|
overrides = {}
|
||||||
self.args = get_config(config, overrides)
|
self.args = get_config(config, overrides)
|
||||||
project = self.args.project or f"runs/{self.args.task}"
|
|
||||||
name = self.args.name or "exp" # hardcode mode as export doesn't require it
|
|
||||||
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
|
|
||||||
self.save_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks
|
self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks
|
||||||
callbacks.add_integration_callbacks(self)
|
callbacks.add_integration_callbacks(self)
|
||||||
|
|
||||||
|
@ -35,7 +35,7 @@ from ultralytics.nn.autobackend import AutoBackend
|
|||||||
from ultralytics.yolo.configs import get_config
|
from ultralytics.yolo.configs import get_config
|
||||||
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadScreenshots, LoadStreams
|
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadScreenshots, LoadStreams
|
||||||
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, callbacks, colorstr, ops
|
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, SETTINGS, callbacks, colorstr, ops
|
||||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
|
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
|
||||||
from ultralytics.yolo.utils.files import increment_path
|
from ultralytics.yolo.utils.files import increment_path
|
||||||
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
|
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
|
||||||
@ -73,7 +73,7 @@ class BasePredictor:
|
|||||||
if overrides is None:
|
if overrides is None:
|
||||||
overrides = {}
|
overrides = {}
|
||||||
self.args = get_config(config, overrides)
|
self.args = get_config(config, overrides)
|
||||||
project = self.args.project or f"runs/{self.args.task}"
|
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
||||||
name = self.args.name or f"{self.args.mode}"
|
name = self.args.name or f"{self.args.mode}"
|
||||||
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
|
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
|
||||||
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
@ -25,7 +25,8 @@ import ultralytics.yolo.utils as utils
|
|||||||
from ultralytics import __version__
|
from ultralytics import __version__
|
||||||
from ultralytics.yolo.configs import get_config
|
from ultralytics.yolo.configs import get_config
|
||||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, TQDM_BAR_FORMAT, callbacks, colorstr, yaml_save
|
from ultralytics.yolo.utils import (DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr,
|
||||||
|
yaml_save)
|
||||||
from ultralytics.yolo.utils.checks import check_file, print_args
|
from ultralytics.yolo.utils.checks import check_file, print_args
|
||||||
from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
|
from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
|
||||||
from ultralytics.yolo.utils.files import get_latest_run, increment_path
|
from ultralytics.yolo.utils.files import get_latest_run, increment_path
|
||||||
@ -88,7 +89,7 @@ class BaseTrainer:
|
|||||||
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
|
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
|
||||||
|
|
||||||
# Dirs
|
# Dirs
|
||||||
project = self.args.project or f"runs/{self.args.task}"
|
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
||||||
name = self.args.name or f"{self.args.mode}"
|
name = self.args.name or f"{self.args.mode}"
|
||||||
self.save_dir = Path(
|
self.save_dir = Path(
|
||||||
self.args.get(
|
self.args.get(
|
||||||
|
@ -8,7 +8,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from ultralytics.nn.autobackend import AutoBackend
|
from ultralytics.nn.autobackend import AutoBackend
|
||||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, TQDM_BAR_FORMAT, callbacks
|
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks
|
||||||
from ultralytics.yolo.utils.checks import check_imgsz
|
from ultralytics.yolo.utils.checks import check_imgsz
|
||||||
from ultralytics.yolo.utils.files import increment_path
|
from ultralytics.yolo.utils.files import increment_path
|
||||||
from ultralytics.yolo.utils.ops import Profile
|
from ultralytics.yolo.utils.ops import Profile
|
||||||
@ -59,7 +59,7 @@ class BaseValidator:
|
|||||||
self.speed = None
|
self.speed = None
|
||||||
self.jdict = None
|
self.jdict = None
|
||||||
|
|
||||||
project = self.args.project or f"runs/{self.args.task}"
|
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
|
||||||
name = self.args.name or f"{self.args.mode}"
|
name = self.args.name or f"{self.args.mode}"
|
||||||
self.save_dir = save_dir or increment_path(Path(project) / name,
|
self.save_dir = save_dir or increment_path(Path(project) / name,
|
||||||
exist_ok=self.args.exist_ok if RANK in {-1, 0} else True)
|
exist_ok=self.args.exist_ok if RANK in {-1, 0} else True)
|
||||||
|
@ -18,7 +18,6 @@ FILE = Path(__file__).resolve()
|
|||||||
ROOT = FILE.parents[2] # YOLO
|
ROOT = FILE.parents[2] # YOLO
|
||||||
DEFAULT_CONFIG = ROOT / "yolo/configs/default.yaml"
|
DEFAULT_CONFIG = ROOT / "yolo/configs/default.yaml"
|
||||||
RANK = int(os.getenv('RANK', -1))
|
RANK = int(os.getenv('RANK', -1))
|
||||||
DATASETS_DIR = Path(os.getenv('YOLOv5_DATASETS_DIR', ROOT.parent / 'datasets')) # global datasets directory
|
|
||||||
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
|
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
|
||||||
AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
|
AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
|
||||||
FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf
|
FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf
|
||||||
@ -119,6 +118,41 @@ def is_docker() -> bool:
|
|||||||
return 'docker' in f.read()
|
return 'docker' in f.read()
|
||||||
|
|
||||||
|
|
||||||
|
def is_git_directory() -> bool:
|
||||||
|
"""
|
||||||
|
Check if the current working directory is inside a git repository.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the current working directory is inside a git repository, False otherwise.
|
||||||
|
"""
|
||||||
|
from git import Repo
|
||||||
|
try:
|
||||||
|
# Check if the current working directory is a git repository
|
||||||
|
Repo(search_parent_directories=True)
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_pip_package(filepath: str = __name__) -> bool:
|
||||||
|
"""
|
||||||
|
Determines if the file at the given filepath is part of a pip package.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filepath (str): The filepath to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the file is part of a pip package, False otherwise.
|
||||||
|
"""
|
||||||
|
import importlib.util
|
||||||
|
|
||||||
|
# Get the spec for the module
|
||||||
|
spec = importlib.util.find_spec(filepath)
|
||||||
|
|
||||||
|
# Return whether the spec is not None and the origin is not None (indicating it is a package)
|
||||||
|
return spec is not None and spec.origin is not None
|
||||||
|
|
||||||
|
|
||||||
def is_dir_writeable(dir_path: str) -> bool:
|
def is_dir_writeable(dir_path: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if a directory is writeable.
|
Check if a directory is writeable.
|
||||||
@ -305,10 +339,11 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml'):
|
|||||||
"""
|
"""
|
||||||
from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first
|
from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first
|
||||||
|
|
||||||
|
git_install = not is_pip_package()
|
||||||
defaults = {
|
defaults = {
|
||||||
'datasets_dir': None, # default datasets directory. If None, current working directory is used.
|
'datasets_dir': str(ROOT / 'datasets') if git_install else 'datasets', # default datasets directory.
|
||||||
'weights_dir': None, # default weights directory. If None, current working directory is used.
|
'weights_dir': str(ROOT / 'weights') if git_install else 'weights', # default weights directory.
|
||||||
'runs_dir': None, # default runs directory. If None, current working directory is used.
|
'runs_dir': str(ROOT / 'runs') if git_install else 'runs', # default runs directory.
|
||||||
'sync': True, # sync analytics to help with YOLO development
|
'sync': True, # sync analytics to help with YOLO development
|
||||||
'uuid': uuid.getnode(), # device UUID to align analytics
|
'uuid': uuid.getnode(), # device UUID to align analytics
|
||||||
'yaml_file': str(file)} # setting YAML file path
|
'yaml_file': str(file)} # setting YAML file path
|
||||||
@ -336,6 +371,7 @@ if platform.system() == 'Windows':
|
|||||||
|
|
||||||
# Check first-install steps
|
# Check first-install steps
|
||||||
SETTINGS = get_settings()
|
SETTINGS = get_settings()
|
||||||
|
DATASETS_DIR = Path(SETTINGS['datasets_dir']) # global datasets directory
|
||||||
|
|
||||||
|
|
||||||
def set_settings(kwargs, file=USER_CONFIG_DIR / 'settings.yaml'):
|
def set_settings(kwargs, file=USER_CONFIG_DIR / 'settings.yaml'):
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import glob
|
import glob
|
||||||
import inspect
|
import inspect
|
||||||
|
import math
|
||||||
import platform
|
import platform
|
||||||
import urllib
|
import urllib
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -13,71 +14,141 @@ import torch
|
|||||||
|
|
||||||
from ultralytics.yolo.utils import (AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, emojis,
|
from ultralytics.yolo.utils import (AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, emojis,
|
||||||
is_docker, is_jupyter_notebook)
|
is_docker, is_jupyter_notebook)
|
||||||
from ultralytics.yolo.utils.ops import make_divisible
|
|
||||||
|
|
||||||
|
|
||||||
def is_ascii(s=''):
|
def is_ascii(s) -> bool:
|
||||||
# Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
|
"""
|
||||||
s = str(s) # convert list, tuple, None, etc. to str
|
Check if a string is composed of only ASCII characters.
|
||||||
return len(s.encode().decode('ascii', 'ignore')) == len(s)
|
|
||||||
|
Args:
|
||||||
|
s (str): String to be checked.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the string is composed only of ASCII characters, False otherwise.
|
||||||
|
"""
|
||||||
|
# Convert list, tuple, None, etc. to string
|
||||||
|
s = str(s)
|
||||||
|
|
||||||
|
# Check if the string is composed of only ASCII characters
|
||||||
|
return all(ord(c) < 128 for c in s)
|
||||||
|
|
||||||
|
|
||||||
def check_imgsz(imgsz, stride=32, min_dim=1, floor=0):
|
def check_imgsz(imgsz, stride=32, min_dim=1, floor=0):
|
||||||
# Verify image size is a multiple of stride s in each dimension
|
"""
|
||||||
|
Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the
|
||||||
|
stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
imgsz (int or List[int]): Image size.
|
||||||
|
stride (int): Stride value.
|
||||||
|
min_dim (int): Minimum number of dimensions.
|
||||||
|
floor (int): Minimum allowed value for image size.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[int]: Updated image size.
|
||||||
|
"""
|
||||||
|
# Convert stride to integer if it is a tensor
|
||||||
stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)
|
stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)
|
||||||
if isinstance(imgsz, int): # integer i.e. imgsz=640
|
|
||||||
sz = max(make_divisible(imgsz, stride), floor)
|
# Convert image size to list if it is an integer
|
||||||
else: # list i.e. imgsz=[640, 480]
|
if isinstance(imgsz, int):
|
||||||
imgsz = list(imgsz) # convert to list if tuple
|
imgsz = [imgsz]
|
||||||
sz = [max(make_divisible(x, stride), floor) for x in imgsz]
|
|
||||||
|
# Make image size a multiple of the stride
|
||||||
|
sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]
|
||||||
|
|
||||||
|
# Print warning message if image size was updated
|
||||||
if sz != imgsz:
|
if sz != imgsz:
|
||||||
LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {stride}, updating to {sz}')
|
LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {stride}, updating to {sz}')
|
||||||
|
|
||||||
# Check dims
|
# Add missing dimensions if necessary
|
||||||
if min_dim == 2:
|
if min_dim == 2 and len(sz) == 1:
|
||||||
if isinstance(imgsz, int):
|
|
||||||
sz = [sz, sz]
|
|
||||||
elif len(sz) == 1:
|
|
||||||
sz = [sz[0], sz[0]]
|
sz = [sz[0], sz[0]]
|
||||||
|
|
||||||
return sz
|
return sz
|
||||||
|
|
||||||
|
|
||||||
def check_version(current="0.0.0", minimum="0.0.0", name="version ", pinned=False, hard=False, verbose=False):
|
def check_version(current: str = "0.0.0",
|
||||||
# Check version vs. required version
|
minimum: str = "0.0.0",
|
||||||
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
|
name: str = "version ",
|
||||||
|
pinned: bool = False,
|
||||||
|
hard: bool = False,
|
||||||
|
verbose: bool = False) -> bool:
|
||||||
|
"""
|
||||||
|
Check current version against the required minimum version.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current (str): Current version.
|
||||||
|
minimum (str): Required minimum version.
|
||||||
|
name (str): Name to be used in warning message.
|
||||||
|
pinned (bool): If True, versions must match exactly. If False, minimum version must be satisfied.
|
||||||
|
hard (bool): If True, raise an AssertionError if the minimum version is not met.
|
||||||
|
verbose (bool): If True, print warning message if minimum version is not met.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if minimum version is met, False otherwise.
|
||||||
|
"""
|
||||||
|
from pkg_resources import parse_version
|
||||||
|
current, minimum = (parse_version(x) for x in (current, minimum))
|
||||||
result = (current == minimum) if pinned else (current >= minimum) # bool
|
result = (current == minimum) if pinned else (current >= minimum) # bool
|
||||||
s = f"WARNING ⚠️ {name}{minimum} is required by YOLOv5, but {name}{current} is currently installed" # string
|
warning_message = f"WARNING ⚠️ {name}{minimum} is required by YOLOv5, but {name}{current} is currently installed"
|
||||||
if hard:
|
if hard:
|
||||||
assert result, emojis(s) # assert min requirements met
|
assert result, emojis(warning_message) # assert min requirements met
|
||||||
if verbose and not result:
|
if verbose and not result:
|
||||||
LOGGER.warning(s)
|
LOGGER.warning(warning_message)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def check_font(font=FONT, progress=False):
|
def check_font(font: str = FONT, progress: bool = False) -> None:
|
||||||
# Download font to CONFIG_DIR if necessary
|
"""
|
||||||
|
Download font file to the user's configuration directory if it does not already exist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
font (str): Path to font file.
|
||||||
|
progress (bool): If True, display a progress bar during the download.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
font = Path(font)
|
font = Path(font)
|
||||||
|
|
||||||
|
# Destination path for the font file
|
||||||
file = USER_CONFIG_DIR / font.name
|
file = USER_CONFIG_DIR / font.name
|
||||||
|
|
||||||
|
# Check if font file exists at the source or destination path
|
||||||
if not font.exists() and not file.exists():
|
if not font.exists() and not file.exists():
|
||||||
|
# Download font file
|
||||||
url = f'https://ultralytics.com/assets/{font.name}'
|
url = f'https://ultralytics.com/assets/{font.name}'
|
||||||
LOGGER.info(f'Downloading {url} to {file}...')
|
LOGGER.info(f'Downloading {url} to {file}...')
|
||||||
torch.hub.download_url_to_file(url, str(file), progress=progress)
|
torch.hub.download_url_to_file(url, str(file), progress=progress)
|
||||||
|
|
||||||
|
|
||||||
def check_online():
|
def check_online() -> bool:
|
||||||
# Check internet connectivity
|
"""
|
||||||
|
Check internet connectivity by attempting to connect to a known online host.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if connection is successful, False otherwise.
|
||||||
|
"""
|
||||||
import socket
|
import socket
|
||||||
try:
|
try:
|
||||||
socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
|
# Check host accessibility by attempting to establish a connection
|
||||||
|
socket.create_connection(("1.1.1.1", 443), timeout=5)
|
||||||
return True
|
return True
|
||||||
except OSError:
|
except OSError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def check_python(minimum='3.7.0'):
|
def check_python(minimum: str = '3.7.0') -> bool:
|
||||||
# Check current python version vs. required python version
|
"""
|
||||||
|
Check current python version against the required minimum version.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
minimum (str): Required minimum version of python.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
check_version(platform.python_version(), minimum, name='Python ', hard=True)
|
check_version(platform.python_version(), minimum, name='Python ', hard=True)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user