mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Remove GitPython dependency (#568)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
3c4de102f6
commit
d6a4ffb778
@ -37,6 +37,11 @@ theme:
|
|||||||
- navigation.footer
|
- navigation.footer
|
||||||
- content.tabs.link # all code tabs change simultaneously
|
- content.tabs.link # all code tabs change simultaneously
|
||||||
|
|
||||||
|
# Version drop-down menu
|
||||||
|
# extra:
|
||||||
|
# version:
|
||||||
|
# provider: mike
|
||||||
|
|
||||||
extra_css:
|
extra_css:
|
||||||
- stylesheets/style.css
|
- stylesheets/style.css
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@
|
|||||||
# Base ----------------------------------------
|
# Base ----------------------------------------
|
||||||
matplotlib>=3.2.2
|
matplotlib>=3.2.2
|
||||||
numpy>=1.18.5
|
numpy>=1.18.5
|
||||||
opencv-python>=4.1.1
|
opencv-python>=4.6.0
|
||||||
Pillow>=7.1.2
|
Pillow>=7.1.2
|
||||||
PyYAML>=5.3.1
|
PyYAML>=5.3.1
|
||||||
requests>=2.23.0
|
requests>=2.23.0
|
||||||
@ -40,6 +40,3 @@ thop>=0.1.1 # FLOPs computation
|
|||||||
# albumentations>=1.0.3
|
# albumentations>=1.0.3
|
||||||
# pycocotools>=2.0.6 # COCO mAP
|
# pycocotools>=2.0.6 # COCO mAP
|
||||||
# roboflow
|
# roboflow
|
||||||
|
|
||||||
# HUB -----------------------------------------
|
|
||||||
GitPython>=3.1.24
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
|
|
||||||
__version__ = "8.0.15"
|
__version__ = "8.0.17"
|
||||||
|
|
||||||
from ultralytics.yolo.engine.model import YOLO
|
from ultralytics.yolo.engine.model import YOLO
|
||||||
from ultralytics.yolo.utils import ops
|
from ultralytics.yolo.utils import ops
|
||||||
|
@ -11,7 +11,7 @@ from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, Bot
|
|||||||
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
|
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
|
||||||
GhostBottleneck, GhostConv, Segment)
|
GhostBottleneck, GhostConv, Segment)
|
||||||
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, yaml_load
|
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, yaml_load
|
||||||
from ultralytics.yolo.utils.checks import check_yaml
|
from ultralytics.yolo.utils.checks import check_requirements, check_yaml
|
||||||
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
|
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
|
||||||
model_info, scale_img, time_sync)
|
model_info, scale_img, time_sync)
|
||||||
|
|
||||||
@ -357,7 +357,16 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
|||||||
# Loads a single model weights
|
# Loads a single model weights
|
||||||
from ultralytics.yolo.utils.downloads import attempt_download
|
from ultralytics.yolo.utils.downloads import attempt_download
|
||||||
|
|
||||||
ckpt = torch.load(attempt_download(weight), map_location='cpu') # load
|
weight = attempt_download(weight)
|
||||||
|
try:
|
||||||
|
ckpt = torch.load(weight, map_location='cpu') # load
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
LOGGER.warning(f"WARNING ⚠️ {weight} is deprecated as it requires omegaconf, which is now removed from "
|
||||||
|
"ultralytics requirements.\nAutoInstall will occur now but this feature will be removed for "
|
||||||
|
"omegaconf models in the future.\nPlease train a new model or download updated models "
|
||||||
|
"from https://github.com/ultralytics/assets/releases/tag/v0.0.0")
|
||||||
|
check_requirements('omegaconf')
|
||||||
|
ckpt = torch.load(weight, map_location='cpu') # load
|
||||||
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
||||||
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||||
|
|
||||||
|
@ -6,11 +6,12 @@ import sys
|
|||||||
from difflib import get_close_matches
|
from difflib import get_close_matches
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Dict, Union
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
from ultralytics import __version__, yolo
|
from ultralytics import __version__, yolo
|
||||||
from ultralytics.yolo.utils import (DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, PREFIX, USER_CONFIG_DIR,
|
from ultralytics.yolo.utils import (DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, PREFIX, USER_CONFIG_DIR,
|
||||||
IterableSimpleNamespace, checks, colorstr, yaml_load, yaml_print)
|
IterableSimpleNamespace, colorstr, yaml_load, yaml_print)
|
||||||
|
from ultralytics.yolo.utils.checks import check_yolo
|
||||||
|
|
||||||
CLI_HELP_MSG = \
|
CLI_HELP_MSG = \
|
||||||
"""
|
"""
|
||||||
@ -111,6 +112,33 @@ def check_cfg_mismatch(base: Dict, custom: Dict):
|
|||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
|
|
||||||
|
def merge_equals_args(args: List[str]) -> List[str]:
|
||||||
|
"""
|
||||||
|
Merges arguments around isolated '=' args in a list of strings.
|
||||||
|
The function considers cases where the first argument ends with '=' or the second starts with '=',
|
||||||
|
as well as when the middle one is an equals sign.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args (List[str]): A list of strings where each element is an argument.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: A list of strings where the arguments around isolated '=' are merged.
|
||||||
|
"""
|
||||||
|
new_args = []
|
||||||
|
for i, arg in enumerate(args):
|
||||||
|
if arg == '=' and 0 < i < len(args) - 1:
|
||||||
|
new_args[-1] += f"={args[i + 1]}"
|
||||||
|
del args[i + 1]
|
||||||
|
elif arg.endswith('=') and i < len(args) - 1:
|
||||||
|
new_args.append(f"{arg}{args[i + 1]}")
|
||||||
|
del args[i + 1]
|
||||||
|
elif arg.startswith('=') and i > 0:
|
||||||
|
new_args[-1] += arg
|
||||||
|
else:
|
||||||
|
new_args.append(arg)
|
||||||
|
return new_args
|
||||||
|
|
||||||
|
|
||||||
def argument_error(arg):
|
def argument_error(arg):
|
||||||
return SyntaxError(f"'{arg}' is not a valid YOLO argument.\n{CLI_HELP_MSG}")
|
return SyntaxError(f"'{arg}' is not a valid YOLO argument.\n{CLI_HELP_MSG}")
|
||||||
|
|
||||||
@ -130,7 +158,7 @@ def entrypoint(debug=False):
|
|||||||
It uses the package's default cfg and initializes it using the passed overrides.
|
It uses the package's default cfg and initializes it using the passed overrides.
|
||||||
Then it calls the CLI function with the composed cfg
|
Then it calls the CLI function with the composed cfg
|
||||||
"""
|
"""
|
||||||
args = ['train', 'predict', 'model=yolov8n.pt'] if debug else sys.argv[1:]
|
args = ['train', 'model=yolov8n.pt', 'data=coco128.yaml', 'imgsz=32', 'epochs=1'] if debug else sys.argv[1:]
|
||||||
if not args: # no arguments passed
|
if not args: # no arguments passed
|
||||||
LOGGER.info(CLI_HELP_MSG)
|
LOGGER.info(CLI_HELP_MSG)
|
||||||
return
|
return
|
||||||
@ -139,14 +167,14 @@ def entrypoint(debug=False):
|
|||||||
modes = 'train', 'val', 'predict', 'export'
|
modes = 'train', 'val', 'predict', 'export'
|
||||||
special = {
|
special = {
|
||||||
'help': lambda: LOGGER.info(CLI_HELP_MSG),
|
'help': lambda: LOGGER.info(CLI_HELP_MSG),
|
||||||
'checks': checks.check_yolo,
|
'checks': check_yolo,
|
||||||
'version': lambda: LOGGER.info(__version__),
|
'version': lambda: LOGGER.info(__version__),
|
||||||
'settings': lambda: yaml_print(USER_CONFIG_DIR / 'settings.yaml'),
|
'settings': lambda: yaml_print(USER_CONFIG_DIR / 'settings.yaml'),
|
||||||
'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
|
'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
|
||||||
'copy-cfg': copy_default_config}
|
'copy-cfg': copy_default_config}
|
||||||
|
|
||||||
overrides = {} # basic overrides, i.e. imgsz=320
|
overrides = {} # basic overrides, i.e. imgsz=320
|
||||||
for a in args:
|
for a in merge_equals_args(args): # merge spaces around '=' sign
|
||||||
if '=' in a:
|
if '=' in a:
|
||||||
try:
|
try:
|
||||||
re.sub(r' *= *', '=', a) # remove spaces around equals sign
|
re.sub(r' *= *', '=', a) # remove spaces around equals sign
|
||||||
@ -185,6 +213,13 @@ def entrypoint(debug=False):
|
|||||||
|
|
||||||
cfg = get_cfg(DEFAULT_CFG_DICT, overrides) # create CFG instance
|
cfg = get_cfg(DEFAULT_CFG_DICT, overrides) # create CFG instance
|
||||||
|
|
||||||
|
# Checks error catch
|
||||||
|
if cfg.mode == 'checks':
|
||||||
|
LOGGER.warning(
|
||||||
|
"WARNING ⚠️ 'yolo mode=checks' is deprecated and will be removed in the future. Use 'yolo checks' instead.")
|
||||||
|
check_yolo()
|
||||||
|
return
|
||||||
|
|
||||||
# Mapping from task to module
|
# Mapping from task to module
|
||||||
module = {"detect": yolo.v8.detect, "segment": yolo.v8.segment, "classify": yolo.v8.classify}.get(cfg.task)
|
module = {"detect": yolo.v8.detect, "segment": yolo.v8.segment, "classify": yolo.v8.classify}.get(cfg.task)
|
||||||
if not module:
|
if not module:
|
||||||
|
@ -231,7 +231,7 @@ def check_dataset_yaml(dataset, autodownload=True):
|
|||||||
if s and autodownload:
|
if s and autodownload:
|
||||||
LOGGER.warning(msg)
|
LOGGER.warning(msg)
|
||||||
else:
|
else:
|
||||||
raise FileNotFoundError(s)
|
raise FileNotFoundError(msg)
|
||||||
t = time.time()
|
t = time.time()
|
||||||
if s.startswith('http') and s.endswith('.zip'): # URL
|
if s.startswith('http') and s.endswith('.zip'): # URL
|
||||||
f = Path(s).name # filename
|
f = Path(s).name # filename
|
||||||
|
@ -508,7 +508,6 @@ class BaseTrainer:
|
|||||||
strip_optimizer(f) # strip optimizers
|
strip_optimizer(f) # strip optimizers
|
||||||
if f is self.best:
|
if f is self.best:
|
||||||
self.console.info(f'\nValidating {f}...')
|
self.console.info(f'\nValidating {f}...')
|
||||||
self.validator.args.save_json = True
|
|
||||||
self.metrics = self.validator(model=f)
|
self.metrics = self.validator(model=f)
|
||||||
self.metrics.pop('fitness', None)
|
self.metrics.pop('fitness', None)
|
||||||
self.run_callbacks('on_fit_epoch_end')
|
self.run_callbacks('on_fit_epoch_end')
|
||||||
|
@ -5,6 +5,7 @@ import inspect
|
|||||||
import logging.config
|
import logging.config
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import threading
|
import threading
|
||||||
@ -14,7 +15,6 @@ from types import SimpleNamespace
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import git
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
@ -124,7 +124,7 @@ def is_colab():
|
|||||||
Returns:
|
Returns:
|
||||||
bool: True if running inside a Colab notebook, False otherwise.
|
bool: True if running inside a Colab notebook, False otherwise.
|
||||||
"""
|
"""
|
||||||
# Check if the google.colab module is present in sys.modules
|
# Check if the 'google.colab' module is present in sys.modules
|
||||||
return 'google.colab' in sys.modules
|
return 'google.colab' in sys.modules
|
||||||
|
|
||||||
|
|
||||||
@ -138,7 +138,7 @@ def is_kaggle():
|
|||||||
return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
|
return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
|
||||||
|
|
||||||
|
|
||||||
def is_jupyter_notebook():
|
def is_jupyter():
|
||||||
"""
|
"""
|
||||||
Check if the current script is running inside a Jupyter Notebook.
|
Check if the current script is running inside a Jupyter Notebook.
|
||||||
Verified on Colab, Jupyterlab, Kaggle, Paperspace.
|
Verified on Colab, Jupyterlab, Kaggle, Paperspace.
|
||||||
@ -146,8 +146,6 @@ def is_jupyter_notebook():
|
|||||||
Returns:
|
Returns:
|
||||||
bool: True if running inside a Jupyter Notebook, False otherwise.
|
bool: True if running inside a Jupyter Notebook, False otherwise.
|
||||||
"""
|
"""
|
||||||
# Check if the get_ipython function exists
|
|
||||||
# (it does not exist when running as a standalone script)
|
|
||||||
try:
|
try:
|
||||||
from IPython import get_ipython
|
from IPython import get_ipython
|
||||||
return get_ipython() is not None
|
return get_ipython() is not None
|
||||||
@ -170,21 +168,6 @@ def is_docker() -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_git_directory() -> bool:
|
|
||||||
"""
|
|
||||||
Check if the current working directory is inside a git repository.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if the current working directory is inside a git repository, False otherwise.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
git.Repo(search_parent_directories=True)
|
|
||||||
# subprocess.run(["git", "rev-parse", "--git-dir"], capture_output=True, check=True) # CLI alternative
|
|
||||||
return True
|
|
||||||
except git.exc.InvalidGitRepositoryError: # subprocess.CalledProcessError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def is_pip_package(filepath: str = __name__) -> bool:
|
def is_pip_package(filepath: str = __name__) -> bool:
|
||||||
"""
|
"""
|
||||||
Determines if the file at the given filepath is part of a pip package.
|
Determines if the file at the given filepath is part of a pip package.
|
||||||
@ -224,8 +207,10 @@ def is_dir_writeable(dir_path: Union[str, Path]) -> bool:
|
|||||||
|
|
||||||
def is_pytest_running():
|
def is_pytest_running():
|
||||||
"""
|
"""
|
||||||
Returns a boolean indicating if pytest is currently running or not
|
Determines whether pytest is currently running or not.
|
||||||
:return: True if pytest is running, False otherwise
|
|
||||||
|
Returns:
|
||||||
|
(bool): True if pytest is running, False otherwise.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
import sys
|
import sys
|
||||||
@ -234,17 +219,53 @@ def is_pytest_running():
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_git_root_dir():
|
def is_github_actions_ci() -> bool:
|
||||||
|
"""
|
||||||
|
Determine if the current environment is a GitHub Actions CI Python runner.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(bool): True if the current environment is a GitHub Actions CI Python runner, False otherwise.
|
||||||
|
"""
|
||||||
|
return 'GITHUB_ACTIONS' in os.environ and 'RUNNER_OS' in os.environ and 'RUNNER_TOOL_CACHE' in os.environ
|
||||||
|
|
||||||
|
|
||||||
|
def is_git_dir():
|
||||||
|
"""
|
||||||
|
Determines whether the current file is part of a git repository.
|
||||||
|
If the current file is not part of a git repository, returns None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(bool): True if current file is part of a git repository.
|
||||||
|
"""
|
||||||
|
return get_git_dir() is not None
|
||||||
|
|
||||||
|
|
||||||
|
def get_git_dir():
|
||||||
"""
|
"""
|
||||||
Determines whether the current file is part of a git repository and if so, returns the repository root directory.
|
Determines whether the current file is part of a git repository and if so, returns the repository root directory.
|
||||||
If the current file is not part of a git repository, returns None.
|
If the current file is not part of a git repository, returns None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(Path) or (None): Git root directory if found or None if not found.
|
||||||
"""
|
"""
|
||||||
try:
|
for d in Path(__file__).parents:
|
||||||
# output = subprocess.run(["git", "rev-parse", "--git-dir"], capture_output=True, check=True)
|
if (d / '.git').is_dir():
|
||||||
# return Path(output.stdout.strip().decode('utf-8')).parent.resolve() # CLI alternative
|
return d
|
||||||
return Path(git.Repo(search_parent_directories=True).working_tree_dir)
|
return None # no .git dir found
|
||||||
except git.exc.InvalidGitRepositoryError: # (subprocess.CalledProcessError, FileNotFoundError):
|
|
||||||
return None
|
|
||||||
|
def get_git_origin_url():
|
||||||
|
"""
|
||||||
|
Retrieves the origin URL of a git repository.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(str) or (None): The origin URL of the git repository.
|
||||||
|
"""
|
||||||
|
if is_git_dir():
|
||||||
|
with contextlib.suppress(subprocess.CalledProcessError):
|
||||||
|
origin = subprocess.check_output(["git", "config", "--get", "remote.origin.url"])
|
||||||
|
return origin.decode().strip()
|
||||||
|
return None # if not git dir or on error
|
||||||
|
|
||||||
|
|
||||||
def get_default_args(func):
|
def get_default_args(func):
|
||||||
@ -316,7 +337,7 @@ def colorstr(*input):
|
|||||||
"bright_white": "\033[97m",
|
"bright_white": "\033[97m",
|
||||||
"end": "\033[0m", # misc
|
"end": "\033[0m", # misc
|
||||||
"bold": "\033[1m",
|
"bold": "\033[1m",
|
||||||
"underline": "\033[4m",}
|
"underline": "\033[4m"}
|
||||||
return "".join(colors[x] for x in args) + f"{string}" + colors["end"]
|
return "".join(colors[x] for x in args) + f"{string}" + colors["end"]
|
||||||
|
|
||||||
|
|
||||||
@ -334,12 +355,12 @@ def set_logging(name=LOGGING_NAME, verbose=True):
|
|||||||
name: {
|
name: {
|
||||||
"class": "logging.StreamHandler",
|
"class": "logging.StreamHandler",
|
||||||
"formatter": name,
|
"formatter": name,
|
||||||
"level": level,}},
|
"level": level}},
|
||||||
"loggers": {
|
"loggers": {
|
||||||
name: {
|
name: {
|
||||||
"level": level,
|
"level": level,
|
||||||
"handlers": [name],
|
"handlers": [name],
|
||||||
"propagate": False,}}})
|
"propagate": False}}})
|
||||||
|
|
||||||
|
|
||||||
class TryExcept(contextlib.ContextDecorator):
|
class TryExcept(contextlib.ContextDecorator):
|
||||||
@ -419,20 +440,34 @@ def yaml_print(yaml_file: Union[str, Path, dict]) -> None:
|
|||||||
LOGGER.info(f"Printing '{colorstr('bold', 'black', yaml_file)}'\n\n{dump}")
|
LOGGER.info(f"Printing '{colorstr('bold', 'black', yaml_file)}'\n\n{dump}")
|
||||||
|
|
||||||
|
|
||||||
def set_sentry(dsn=None):
|
def set_sentry():
|
||||||
"""
|
"""
|
||||||
Initialize the Sentry SDK for error tracking and reporting if pytest is not currently running.
|
Initialize the Sentry SDK for error tracking and reporting if pytest is not currently running.
|
||||||
"""
|
"""
|
||||||
if dsn and not is_pytest_running():
|
|
||||||
|
def before_send(event, hint):
|
||||||
|
if is_git_dir() and get_git_origin_url() != "https://github.com/ultralytics/ultralytics.git":
|
||||||
|
return None
|
||||||
|
event_os = 'colab' if is_colab() else 'kaggle' if is_kaggle() else 'jupyter' if is_jupyter() else \
|
||||||
|
'docker' if is_docker() else platform.system()
|
||||||
|
event['tags'] = {
|
||||||
|
"sys_argv": sys.argv[0],
|
||||||
|
"sys_argv_name": Path(sys.argv[0]).name,
|
||||||
|
"install": 'git' if is_git_dir() else 'pip' if is_pip_package() else 'other',
|
||||||
|
"os": event_os}
|
||||||
|
return event
|
||||||
|
|
||||||
|
if SETTINGS['sync'] and not is_pytest_running() or is_github_actions_ci():
|
||||||
import sentry_sdk # noqa
|
import sentry_sdk # noqa
|
||||||
|
|
||||||
import ultralytics
|
import ultralytics
|
||||||
sentry_sdk.init(
|
sentry_sdk.init(
|
||||||
dsn=dsn,
|
dsn="https://1f331c322109416595df20a91f4005d3@o4504521589325824.ingest.sentry.io/4504521592406016",
|
||||||
debug=False,
|
debug=False,
|
||||||
traces_sample_rate=1.0,
|
traces_sample_rate=1.0,
|
||||||
release=ultralytics.__version__,
|
release=ultralytics.__version__,
|
||||||
environment='production', # 'dev' or 'production'
|
environment='production', # 'dev' or 'production'
|
||||||
|
before_send=before_send,
|
||||||
ignore_errors=[KeyboardInterrupt])
|
ignore_errors=[KeyboardInterrupt])
|
||||||
|
|
||||||
|
|
||||||
@ -450,9 +485,9 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'):
|
|||||||
from ultralytics.yolo.utils.checks import check_version
|
from ultralytics.yolo.utils.checks import check_version
|
||||||
from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first
|
from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first
|
||||||
|
|
||||||
is_git = is_git_directory() # True if ultralytics installed via git
|
git_dir = get_git_dir()
|
||||||
root = get_git_root_dir() if is_git else Path()
|
root = git_dir or Path()
|
||||||
datasets_root = (root.parent if (is_git and is_dir_writeable(root.parent)) else root).resolve()
|
datasets_root = (root.parent if git_dir and is_dir_writeable(root.parent) else root).resolve()
|
||||||
defaults = {
|
defaults = {
|
||||||
'datasets_dir': str(datasets_root / 'datasets'), # default datasets directory.
|
'datasets_dir': str(datasets_root / 'datasets'), # default datasets directory.
|
||||||
'weights_dir': str(root / 'weights'), # default weights directory.
|
'weights_dir': str(root / 'weights'), # default weights directory.
|
||||||
@ -464,13 +499,13 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml', version='0.0.1'):
|
|||||||
with torch_distributed_zero_first(RANK):
|
with torch_distributed_zero_first(RANK):
|
||||||
if not file.exists():
|
if not file.exists():
|
||||||
yaml_save(file, defaults)
|
yaml_save(file, defaults)
|
||||||
|
|
||||||
settings = yaml_load(file)
|
settings = yaml_load(file)
|
||||||
|
|
||||||
# Check that settings keys and types match defaults
|
# Check that settings keys and types match defaults
|
||||||
correct = settings.keys() == defaults.keys() \
|
correct = \
|
||||||
and all(type(a) == type(b) for a, b in zip(settings.values(), defaults.values())) \
|
settings.keys() == defaults.keys() \
|
||||||
and check_version(settings['settings_version'], version)
|
and all(type(a) == type(b) for a, b in zip(settings.values(), defaults.values())) \
|
||||||
|
and check_version(settings['settings_version'], version)
|
||||||
if not correct:
|
if not correct:
|
||||||
LOGGER.warning('WARNING ⚠️ Ultralytics settings reset to defaults. '
|
LOGGER.warning('WARNING ⚠️ Ultralytics settings reset to defaults. '
|
||||||
'\nThis is normal and may be due to a recent ultralytics package update, '
|
'\nThis is normal and may be due to a recent ultralytics package update, '
|
||||||
|
@ -19,7 +19,7 @@ import torch
|
|||||||
from IPython import display
|
from IPython import display
|
||||||
|
|
||||||
from ultralytics.yolo.utils import (AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, emojis,
|
from ultralytics.yolo.utils import (AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, emojis,
|
||||||
is_colab, is_docker, is_jupyter_notebook)
|
is_colab, is_docker, is_jupyter)
|
||||||
|
|
||||||
|
|
||||||
def is_ascii(s) -> bool:
|
def is_ascii(s) -> bool:
|
||||||
@ -238,7 +238,7 @@ def check_yaml(file, suffix=('.yaml', '.yml')):
|
|||||||
def check_imshow(warn=False):
|
def check_imshow(warn=False):
|
||||||
# Check if environment supports image displays
|
# Check if environment supports image displays
|
||||||
try:
|
try:
|
||||||
assert not is_jupyter_notebook()
|
assert not is_jupyter()
|
||||||
assert not is_docker()
|
assert not is_docker()
|
||||||
cv2.imshow('test', np.zeros((1, 1, 3)))
|
cv2.imshow('test', np.zeros((1, 1, 3)))
|
||||||
cv2.waitKey(1)
|
cv2.waitKey(1)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user