mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-07-08 22:54:54 +08:00
imgsz
warning fix, download function consolidation (#681)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: HaeJin Lee <seareale@gmail.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
parent
0609561549
commit
899abe9f82
@ -1,6 +1,6 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
|
|
||||||
__version__ = "8.0.22"
|
__version__ = "8.0.23"
|
||||||
|
|
||||||
from ultralytics.yolo.engine.model import YOLO
|
from ultralytics.yolo.engine.model import YOLO
|
||||||
from ultralytics.yolo.utils import ops
|
from ultralytics.yolo.utils import ops
|
||||||
|
@ -14,7 +14,7 @@ from PIL import Image
|
|||||||
|
|
||||||
from ultralytics.yolo.utils import LOGGER, ROOT, yaml_load
|
from ultralytics.yolo.utils import LOGGER, ROOT, yaml_load
|
||||||
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_version
|
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_version
|
||||||
from ultralytics.yolo.utils.downloads import attempt_download, is_url
|
from ultralytics.yolo.utils.downloads import attempt_download_asset, is_url
|
||||||
from ultralytics.yolo.utils.ops import xywh2xyxy
|
from ultralytics.yolo.utils.ops import xywh2xyxy
|
||||||
|
|
||||||
|
|
||||||
@ -58,7 +58,7 @@ class AutoBackend(nn.Module):
|
|||||||
model = None # TODO: resolves ONNX inference, verify effect on other backends
|
model = None # TODO: resolves ONNX inference, verify effect on other backends
|
||||||
cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA
|
cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA
|
||||||
if not (pt or triton or nn_module):
|
if not (pt or triton or nn_module):
|
||||||
w = attempt_download(w) # download if not local
|
w = attempt_download_asset(w) # download if not local
|
||||||
|
|
||||||
# NOTE: special case: in-memory pytorch model
|
# NOTE: special case: in-memory pytorch model
|
||||||
if nn_module:
|
if nn_module:
|
||||||
|
@ -325,9 +325,9 @@ def torch_safe_load(weight):
|
|||||||
Returns:
|
Returns:
|
||||||
The loaded PyTorch model.
|
The loaded PyTorch model.
|
||||||
"""
|
"""
|
||||||
from ultralytics.yolo.utils.downloads import attempt_download
|
from ultralytics.yolo.utils.downloads import attempt_download_asset
|
||||||
|
|
||||||
file = attempt_download(weight) # search online if missing locally
|
file = attempt_download_asset(weight) # search online if missing locally
|
||||||
try:
|
try:
|
||||||
return torch.load(file, map_location='cpu') # load
|
return torch.load(file, map_location='cpu') # load
|
||||||
except ModuleNotFoundError as e:
|
except ModuleNotFoundError as e:
|
||||||
|
@ -90,7 +90,7 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG, override
|
|||||||
|
|
||||||
# Type checks
|
# Type checks
|
||||||
for k in 'project', 'name':
|
for k in 'project', 'name':
|
||||||
if isinstance(cfg[k], (int, float)):
|
if k in cfg and isinstance(cfg[k], (int, float)):
|
||||||
cfg[k] = str(cfg[k])
|
cfg[k] = str(cfg[k])
|
||||||
|
|
||||||
# Return instance
|
# Return instance
|
||||||
@ -176,7 +176,7 @@ def entrypoint(debug=False):
|
|||||||
'version': lambda: LOGGER.info(__version__),
|
'version': lambda: LOGGER.info(__version__),
|
||||||
'settings': lambda: yaml_print(USER_CONFIG_DIR / 'settings.yaml'),
|
'settings': lambda: yaml_print(USER_CONFIG_DIR / 'settings.yaml'),
|
||||||
'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
|
'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
|
||||||
'copy-cfg': copy_default_config}
|
'copy-cfg': copy_default_cfg}
|
||||||
|
|
||||||
overrides = {} # basic overrides, i.e. imgsz=320
|
overrides = {} # basic overrides, i.e. imgsz=320
|
||||||
for a in merge_equals_args(args): # merge spaces around '=' sign
|
for a in merge_equals_args(args): # merge spaces around '=' sign
|
||||||
@ -221,7 +221,7 @@ def entrypoint(debug=False):
|
|||||||
task2data = dict(detect='coco128.yaml', segment='coco128-seg.yaml', classify='mnist160')
|
task2data = dict(detect='coco128.yaml', segment='coco128-seg.yaml', classify='mnist160')
|
||||||
|
|
||||||
# Mode
|
# Mode
|
||||||
mode = overrides['mode']
|
mode = overrides.get('mode', None)
|
||||||
if mode is None:
|
if mode is None:
|
||||||
mode = DEFAULT_CFG.mode or 'predict'
|
mode = DEFAULT_CFG.mode or 'predict'
|
||||||
LOGGER.warning(f"WARNING ⚠️ 'mode=' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")
|
LOGGER.warning(f"WARNING ⚠️ 'mode=' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")
|
||||||
@ -266,7 +266,7 @@ def entrypoint(debug=False):
|
|||||||
|
|
||||||
|
|
||||||
# Special modes --------------------------------------------------------------------------------------------------------
|
# Special modes --------------------------------------------------------------------------------------------------------
|
||||||
def copy_default_config():
|
def copy_default_cfg():
|
||||||
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
|
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
|
||||||
shutil.copy2(DEFAULT_CFG_PATH, new_file)
|
shutil.copy2(DEFAULT_CFG_PATH, new_file)
|
||||||
LOGGER.info(f"{PREFIX}{DEFAULT_CFG_PATH} copied to {new_file}\n"
|
LOGGER.info(f"{PREFIX}{DEFAULT_CFG_PATH} copied to {new_file}\n"
|
||||||
|
@ -44,8 +44,20 @@ class Compose:
|
|||||||
self.transforms = transforms
|
self.transforms = transforms
|
||||||
|
|
||||||
def __call__(self, data):
|
def __call__(self, data):
|
||||||
|
mosaic_p = None
|
||||||
|
mosaic_imgsz = None
|
||||||
|
|
||||||
for t in self.transforms:
|
for t in self.transforms:
|
||||||
|
if isinstance(t, Mosaic):
|
||||||
|
temp = t(data)
|
||||||
|
mosaic_p = False if temp == data else True
|
||||||
|
mosaic_imgsz = t.imgsz
|
||||||
|
data = temp
|
||||||
|
else:
|
||||||
|
if isinstance(t, RandomPerspective):
|
||||||
|
t.border = [-mosaic_imgsz // 2, -mosaic_imgsz // 2] if mosaic_p else [0, 0]
|
||||||
data = t(data)
|
data = t(data)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def append(self, transform):
|
def append(self, transform):
|
||||||
|
@ -120,7 +120,8 @@ class BaseDataset(Dataset):
|
|||||||
im = np.load(fn)
|
im = np.load(fn)
|
||||||
else: # read image
|
else: # read image
|
||||||
im = cv2.imread(f) # BGR
|
im = cv2.imread(f) # BGR
|
||||||
assert im is not None, f"Image Not Found {f}"
|
if im is None:
|
||||||
|
raise FileNotFoundError(f"Image Not Found {f}")
|
||||||
h0, w0 = im.shape[:2] # orig hw
|
h0, w0 = im.shape[:2] # orig hw
|
||||||
r = self.imgsz / max(h0, w0) # ratio
|
r = self.imgsz / max(h0, w0) # ratio
|
||||||
if r != 1: # if sizes are not equal
|
if r != 1: # if sizes are not equal
|
||||||
|
@ -65,7 +65,7 @@ def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank
|
|||||||
assert mode in ["train", "val"]
|
assert mode in ["train", "val"]
|
||||||
shuffle = mode == "train"
|
shuffle = mode == "train"
|
||||||
if cfg.rect and shuffle:
|
if cfg.rect and shuffle:
|
||||||
LOGGER.warning("WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False")
|
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
|
||||||
shuffle = False
|
shuffle = False
|
||||||
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
||||||
dataset = YOLODataset(
|
dataset = YOLODataset(
|
||||||
|
@ -64,7 +64,7 @@ download: |
|
|||||||
# Download
|
# Download
|
||||||
dir = Path(yaml['path']) # dataset root dir
|
dir = Path(yaml['path']) # dataset root dir
|
||||||
urls = ['https://argoverse-hd.s3.us-east-2.amazonaws.com/Argoverse-HD-Full.zip']
|
urls = ['https://argoverse-hd.s3.us-east-2.amazonaws.com/Argoverse-HD-Full.zip']
|
||||||
download(urls, dir=dir, delete=False)
|
download(urls, dir=dir)
|
||||||
|
|
||||||
# Convert
|
# Convert
|
||||||
annotations_dir = 'Argoverse-HD/annotations/'
|
annotations_dir = 'Argoverse-HD/annotations/'
|
||||||
|
@ -411,12 +411,12 @@ download: |
|
|||||||
# Download
|
# Download
|
||||||
url = f"https://dorc.ks3-cn-beijing.ksyun.com/data-set/2020Objects365%E6%95%B0%E6%8D%AE%E9%9B%86/{split}/"
|
url = f"https://dorc.ks3-cn-beijing.ksyun.com/data-set/2020Objects365%E6%95%B0%E6%8D%AE%E9%9B%86/{split}/"
|
||||||
if split == 'train':
|
if split == 'train':
|
||||||
download([f'{url}zhiyuan_objv2_{split}.tar.gz'], dir=dir, delete=False) # annotations json
|
download([f'{url}zhiyuan_objv2_{split}.tar.gz'], dir=dir) # annotations json
|
||||||
download([f'{url}patch{i}.tar.gz' for i in range(patches)], dir=images, curl=True, delete=False, threads=8)
|
download([f'{url}patch{i}.tar.gz' for i in range(patches)], dir=images, curl=True, threads=8)
|
||||||
elif split == 'val':
|
elif split == 'val':
|
||||||
download([f'{url}zhiyuan_objv2_{split}.json'], dir=dir, delete=False) # annotations json
|
download([f'{url}zhiyuan_objv2_{split}.json'], dir=dir) # annotations json
|
||||||
download([f'{url}images/v1/patch{i}.tar.gz' for i in range(15 + 1)], dir=images, curl=True, delete=False, threads=8)
|
download([f'{url}images/v1/patch{i}.tar.gz' for i in range(15 + 1)], dir=images, curl=True, threads=8)
|
||||||
download([f'{url}images/v2/patch{i}.tar.gz' for i in range(16, patches)], dir=images, curl=True, delete=False, threads=8)
|
download([f'{url}images/v2/patch{i}.tar.gz' for i in range(16, patches)], dir=images, curl=True, threads=8)
|
||||||
|
|
||||||
# Move
|
# Move
|
||||||
for f in tqdm(images.rglob('*.jpg'), desc=f'Moving {split} images'):
|
for f in tqdm(images.rglob('*.jpg'), desc=f'Moving {split} images'):
|
||||||
|
@ -34,7 +34,7 @@ download: |
|
|||||||
dir = Path(yaml['path']) # dataset root dir
|
dir = Path(yaml['path']) # dataset root dir
|
||||||
parent = Path(dir.parent) # download dir
|
parent = Path(dir.parent) # download dir
|
||||||
urls = ['http://trax-geometry.s3.amazonaws.com/cvpr_challenge/SKU110K_fixed.tar.gz']
|
urls = ['http://trax-geometry.s3.amazonaws.com/cvpr_challenge/SKU110K_fixed.tar.gz']
|
||||||
download(urls, dir=parent, delete=False)
|
download(urls, dir=parent)
|
||||||
|
|
||||||
# Rename directories
|
# Rename directories
|
||||||
if dir.exists():
|
if dir.exists():
|
||||||
|
@ -81,7 +81,7 @@ download: |
|
|||||||
urls = [f'{url}VOCtrainval_06-Nov-2007.zip', # 446MB, 5012 images
|
urls = [f'{url}VOCtrainval_06-Nov-2007.zip', # 446MB, 5012 images
|
||||||
f'{url}VOCtest_06-Nov-2007.zip', # 438MB, 4953 images
|
f'{url}VOCtest_06-Nov-2007.zip', # 438MB, 4953 images
|
||||||
f'{url}VOCtrainval_11-May-2012.zip'] # 1.95GB, 17126 images
|
f'{url}VOCtrainval_11-May-2012.zip'] # 1.95GB, 17126 images
|
||||||
download(urls, dir=dir / 'images', delete=False, curl=True, threads=3)
|
download(urls, dir=dir / 'images', curl=True, threads=3)
|
||||||
|
|
||||||
# Convert
|
# Convert
|
||||||
path = dir / 'images/VOCdevkit'
|
path = dir / 'images/VOCdevkit'
|
||||||
|
@ -138,7 +138,7 @@ download: |
|
|||||||
# urls = ['https://d307kc0mrhucc3.cloudfront.net/train_labels.zip', # train labels
|
# urls = ['https://d307kc0mrhucc3.cloudfront.net/train_labels.zip', # train labels
|
||||||
# 'https://d307kc0mrhucc3.cloudfront.net/train_images.zip', # 15G, 847 train images
|
# 'https://d307kc0mrhucc3.cloudfront.net/train_images.zip', # 15G, 847 train images
|
||||||
# 'https://d307kc0mrhucc3.cloudfront.net/val_images.zip'] # 5G, 282 val images (no labels)
|
# 'https://d307kc0mrhucc3.cloudfront.net/val_images.zip'] # 5G, 282 val images (no labels)
|
||||||
# download(urls, dir=dir, delete=False)
|
# download(urls, dir=dir)
|
||||||
|
|
||||||
# Convert labels
|
# Convert labels
|
||||||
convert_labels(dir / 'xView_train.geojson')
|
convert_labels(dir / 'xView_train.geojson')
|
||||||
|
@ -237,11 +237,7 @@ def check_det_dataset(dataset, autodownload=True):
|
|||||||
raise FileNotFoundError(msg)
|
raise FileNotFoundError(msg)
|
||||||
t = time.time()
|
t = time.time()
|
||||||
if s.startswith('http') and s.endswith('.zip'): # URL
|
if s.startswith('http') and s.endswith('.zip'): # URL
|
||||||
f = Path(s).name # filename
|
safe_download(url=s, dir=DATASETS_DIR, delete=True)
|
||||||
safe_download(file=f, url=s)
|
|
||||||
Path(DATASETS_DIR).mkdir(parents=True, exist_ok=True) # create root
|
|
||||||
unzip_file(f, path=DATASETS_DIR) # unzip
|
|
||||||
Path(f).unlink() # remove zip
|
|
||||||
r = None # success
|
r = None # success
|
||||||
elif s.startswith('bash '): # bash script
|
elif s.startswith('bash '): # bash script
|
||||||
LOGGER.info(f'Running {s} ...')
|
LOGGER.info(f'Running {s} ...')
|
||||||
@ -251,7 +247,7 @@ def check_det_dataset(dataset, autodownload=True):
|
|||||||
dt = f'({round(time.time() - t, 1)}s)'
|
dt = f'({round(time.time() - t, 1)}s)'
|
||||||
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
|
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
|
||||||
LOGGER.info(f"Dataset download {s}")
|
LOGGER.info(f"Dataset download {s}")
|
||||||
check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True) # download fonts
|
check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf') # download fonts
|
||||||
|
|
||||||
return data # dictionary
|
return data # dictionary
|
||||||
|
|
||||||
|
@ -125,7 +125,7 @@ class Exporter:
|
|||||||
Initializes the Exporter class.
|
Initializes the Exporter class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
|
||||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||||
"""
|
"""
|
||||||
self.args = get_cfg(cfg, overrides)
|
self.args = get_cfg(cfg, overrides)
|
||||||
|
@ -28,7 +28,6 @@ DEFAULT_CFG_PATH = ROOT / "yolo/cfg/default.yaml"
|
|||||||
RANK = int(os.getenv('RANK', -1))
|
RANK = int(os.getenv('RANK', -1))
|
||||||
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
|
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
|
||||||
AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
|
AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
|
||||||
FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf
|
|
||||||
VERBOSE = str(os.getenv('YOLO_VERBOSE', True)).lower() == 'true' # global verbose mode
|
VERBOSE = str(os.getenv('YOLO_VERBOSE', True)).lower() == 'true' # global verbose mode
|
||||||
TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' # tqdm bar format
|
TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' # tqdm bar format
|
||||||
LOGGING_NAME = 'ultralytics'
|
LOGGING_NAME = 'ultralytics'
|
||||||
@ -328,6 +327,20 @@ def get_git_origin_url():
|
|||||||
return None # if not git dir or on error
|
return None # if not git dir or on error
|
||||||
|
|
||||||
|
|
||||||
|
def get_git_branch():
|
||||||
|
"""
|
||||||
|
Returns the current git branch name. If not in a git repository, returns None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(str) or (None): The current git branch name.
|
||||||
|
"""
|
||||||
|
if is_git_dir():
|
||||||
|
with contextlib.suppress(subprocess.CalledProcessError):
|
||||||
|
origin = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
|
||||||
|
return origin.decode().strip()
|
||||||
|
return None # if not git dir or on error
|
||||||
|
|
||||||
|
|
||||||
def get_default_args(func):
|
def get_default_args(func):
|
||||||
# Get func() default arguments
|
# Get func() default arguments
|
||||||
signature = inspect.signature(func)
|
signature = inspect.signature(func)
|
||||||
@ -466,7 +479,8 @@ def set_sentry():
|
|||||||
if SETTINGS['sync'] and \
|
if SETTINGS['sync'] and \
|
||||||
not is_pytest_running() and \
|
not is_pytest_running() and \
|
||||||
not is_github_actions_ci() and \
|
not is_github_actions_ci() and \
|
||||||
(is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git"):
|
(is_pip_package() or
|
||||||
|
(get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git" and get_git_branch() == "main")):
|
||||||
import sentry_sdk # noqa
|
import sentry_sdk # noqa
|
||||||
|
|
||||||
import ultralytics
|
import ultralytics
|
||||||
|
@ -28,7 +28,7 @@ def autobatch(model, imgsz=640, fraction=0.7, batch_size=16):
|
|||||||
|
|
||||||
# Check device
|
# Check device
|
||||||
prefix = colorstr('AutoBatch: ')
|
prefix = colorstr('AutoBatch: ')
|
||||||
LOGGER.info(f'{prefix}Computing optimal batch size for --imgsz {imgsz}')
|
LOGGER.info(f'{prefix}Computing optimal batch size for imgsz={imgsz}')
|
||||||
device = next(model.parameters()).device # get model device
|
device = next(model.parameters()).device # get model device
|
||||||
if device.type == 'cpu':
|
if device.type == 'cpu':
|
||||||
LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
|
LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
|
||||||
|
@ -17,9 +17,10 @@ import pkg_resources as pkg
|
|||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
from IPython import display
|
from IPython import display
|
||||||
|
from matplotlib import font_manager
|
||||||
|
|
||||||
from ultralytics.yolo.utils import (AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, downloads,
|
from ultralytics.yolo.utils import (AUTOINSTALL, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, downloads, emojis,
|
||||||
emojis, is_colab, is_docker, is_jupyter)
|
is_colab, is_docker, is_jupyter)
|
||||||
|
|
||||||
|
|
||||||
def is_ascii(s) -> bool:
|
def is_ascii(s) -> bool:
|
||||||
@ -57,15 +58,14 @@ def check_imgsz(imgsz, stride=32, min_dim=1, floor=0):
|
|||||||
stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)
|
stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)
|
||||||
|
|
||||||
# Convert image size to list if it is an integer
|
# Convert image size to list if it is an integer
|
||||||
if isinstance(imgsz, int):
|
imgsz = [imgsz] if isinstance(imgsz, int) else list(imgsz)
|
||||||
imgsz = [imgsz]
|
|
||||||
|
|
||||||
# Make image size a multiple of the stride
|
# Make image size a multiple of the stride
|
||||||
sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]
|
sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]
|
||||||
|
|
||||||
# Print warning message if image size was updated
|
# Print warning message if image size was updated
|
||||||
if sz != imgsz:
|
if sz != imgsz:
|
||||||
LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {stride}, updating to {sz}')
|
LOGGER.warning(f'WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}')
|
||||||
|
|
||||||
# Add missing dimensions if necessary
|
# Add missing dimensions if necessary
|
||||||
sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz
|
sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz
|
||||||
@ -104,26 +104,33 @@ def check_version(current: str = "0.0.0",
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def check_font(font: str = FONT, progress: bool = False) -> None:
|
def check_font(font='Arial.ttf'):
|
||||||
"""
|
"""
|
||||||
Download font file to the user's configuration directory if it does not already exist.
|
Find font locally or download to user's configuration directory if it does not already exist.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
font (str): Path to font file.
|
font (str): Path or name of font.
|
||||||
progress (bool): If True, display a progress bar during the download.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
None
|
file (Path): Resolved font file path.
|
||||||
"""
|
"""
|
||||||
font = Path(font)
|
name = Path(font).name
|
||||||
|
|
||||||
# Destination path for the font file
|
# Check USER_CONFIG_DIR
|
||||||
file = USER_CONFIG_DIR / font.name
|
file = USER_CONFIG_DIR / name
|
||||||
|
if file.exists():
|
||||||
|
return file
|
||||||
|
|
||||||
# Check if font file exists at the source or destination path
|
# Check system fonts
|
||||||
if not font.exists() and not file.exists():
|
matches = [s for s in font_manager.findSystemFonts() if font in s]
|
||||||
# Download font file
|
if any(matches):
|
||||||
downloads.safe_download(file=file, url=f'https://ultralytics.com/assets/{font.name}', progress=progress)
|
return matches[0]
|
||||||
|
|
||||||
|
# Download to USER_CONFIG_DIR if missing
|
||||||
|
url = f'https://ultralytics.com/assets/{name}'
|
||||||
|
if downloads.is_url(url):
|
||||||
|
downloads.safe_download(url=url, file=file)
|
||||||
|
return file
|
||||||
|
|
||||||
|
|
||||||
def check_online() -> bool:
|
def check_online() -> bool:
|
||||||
@ -213,7 +220,7 @@ def check_file(file, suffix=''):
|
|||||||
if Path(file).is_file():
|
if Path(file).is_file():
|
||||||
LOGGER.info(f'Found {url} locally at {file}') # file already exists
|
LOGGER.info(f'Found {url} locally at {file}') # file already exists
|
||||||
else:
|
else:
|
||||||
downloads.safe_download(file=file, url=url)
|
downloads.safe_download(url=url, file=file)
|
||||||
return file
|
return file
|
||||||
else: # search
|
else: # search
|
||||||
files = []
|
files = []
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||||
|
|
||||||
import logging
|
import contextlib
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import urllib
|
import urllib
|
||||||
@ -15,27 +15,6 @@ import torch
|
|||||||
from ultralytics.yolo.utils import LOGGER
|
from ultralytics.yolo.utils import LOGGER
|
||||||
|
|
||||||
|
|
||||||
def safe_download(file, url, url2=None, min_bytes=1E0, error_msg='', progress=True):
|
|
||||||
# Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
|
|
||||||
file = Path(file)
|
|
||||||
assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
|
|
||||||
try: # url1
|
|
||||||
LOGGER.info(f'Downloading {url} to {file}...')
|
|
||||||
torch.hub.download_url_to_file(url, str(file), progress=progress and LOGGER.level <= logging.INFO)
|
|
||||||
assert file.exists() and file.stat().st_size > min_bytes, assert_msg # check
|
|
||||||
except Exception as e: # url2
|
|
||||||
if file.exists():
|
|
||||||
file.unlink() # remove partial downloads
|
|
||||||
LOGGER.info(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...')
|
|
||||||
os.system(f"curl -# -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
|
|
||||||
finally:
|
|
||||||
if not file.exists() or file.stat().st_size < min_bytes: # check
|
|
||||||
if file.exists():
|
|
||||||
file.unlink() # remove partial downloads
|
|
||||||
LOGGER.warning(f"ERROR: {assert_msg}\n{error_msg}")
|
|
||||||
LOGGER.info('')
|
|
||||||
|
|
||||||
|
|
||||||
def is_url(url, check=True):
|
def is_url(url, check=True):
|
||||||
# Check if string is URL and check if URL exists
|
# Check if string is URL and check if URL exists
|
||||||
try:
|
try:
|
||||||
@ -47,7 +26,71 @@ def is_url(url, check=True):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'):
|
def safe_download(url,
|
||||||
|
file=None,
|
||||||
|
dir=None,
|
||||||
|
unzip=True,
|
||||||
|
delete=False,
|
||||||
|
curl=False,
|
||||||
|
retry=3,
|
||||||
|
min_bytes=1E0,
|
||||||
|
progress=True):
|
||||||
|
"""
|
||||||
|
Function for downloading files from a URL, with options for retrying, unzipping, and deleting the downloaded file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: str: The URL of the file to be downloaded.
|
||||||
|
file: str, optional: The filename of the downloaded file.
|
||||||
|
If not provided, the file will be saved with the same name as the URL.
|
||||||
|
dir: str, optional: The directory to save the downloaded file.
|
||||||
|
If not provided, the file will be saved in the current working directory.
|
||||||
|
unzip: bool, optional: Whether to unzip the downloaded file. Default: True.
|
||||||
|
delete: bool, optional: Whether to delete the downloaded file after unzipping. Default: False.
|
||||||
|
curl: bool, optional: Whether to use curl command line tool for downloading. Default: False.
|
||||||
|
retry: int, optional: The number of times to retry the download in case of failure. Default: 3.
|
||||||
|
min_bytes: float, optional: The minimum number of bytes that the downloaded file should have, to be considered
|
||||||
|
a successful download. Default: 1E0.
|
||||||
|
progress: bool, optional: Whether to display a progress bar during the download. Default: True.
|
||||||
|
"""
|
||||||
|
if '://' not in str(url) and Path(url).is_file(): # exists ('://' check required in Windows Python<3.10)
|
||||||
|
f = Path(url) # filename
|
||||||
|
else: # does not exist
|
||||||
|
assert dir or file, 'dir or file required for download'
|
||||||
|
f = dir / Path(url).name if dir else Path(file)
|
||||||
|
LOGGER.info(f'Downloading {url} to {f}...')
|
||||||
|
f.parent.mkdir(parents=True, exist_ok=True) # make directory if missing
|
||||||
|
for i in range(retry + 1):
|
||||||
|
try:
|
||||||
|
if curl or i > 0: # curl download with retry, continue
|
||||||
|
s = 'sS' * (not progress) # silent
|
||||||
|
r = os.system(f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -')
|
||||||
|
else: # torch download
|
||||||
|
r = torch.hub.download_url_to_file(url, f, progress=progress)
|
||||||
|
assert r in {0, None}
|
||||||
|
except Exception as e:
|
||||||
|
if i >= retry:
|
||||||
|
raise ConnectionError(f'❌ Download failure for {url}') from e
|
||||||
|
LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
|
||||||
|
continue
|
||||||
|
|
||||||
|
if f.exists():
|
||||||
|
if f.stat().st_size > min_bytes:
|
||||||
|
break # success
|
||||||
|
f.unlink() # remove partial downloads
|
||||||
|
|
||||||
|
if unzip and f.exists() and f.suffix in {'.zip', '.tar', '.gz'}:
|
||||||
|
LOGGER.info(f'Unzipping {f}...')
|
||||||
|
if f.suffix == '.zip':
|
||||||
|
ZipFile(f).extractall(path=f.parent) # unzip
|
||||||
|
elif f.suffix == '.tar':
|
||||||
|
os.system(f'tar xf {f} --directory {f.parent}') # unzip
|
||||||
|
elif f.suffix == '.gz':
|
||||||
|
os.system(f'tar xfz {f} --directory {f.parent}') # unzip
|
||||||
|
if delete:
|
||||||
|
f.unlink() # remove zip
|
||||||
|
|
||||||
|
|
||||||
|
def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
|
||||||
# Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
|
# Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
|
||||||
from ultralytics.yolo.utils import SETTINGS
|
from ultralytics.yolo.utils import SETTINGS
|
||||||
|
|
||||||
@ -73,7 +116,7 @@ def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'):
|
|||||||
if Path(file).is_file():
|
if Path(file).is_file():
|
||||||
LOGGER.info(f'Found {url} locally at {file}') # file already exists
|
LOGGER.info(f'Found {url} locally at {file}') # file already exists
|
||||||
else:
|
else:
|
||||||
safe_download(file=file, url=url, min_bytes=1E5)
|
safe_download(url=url, file=file, min_bytes=1E5)
|
||||||
return file
|
return file
|
||||||
|
|
||||||
# GitHub assets
|
# GitHub assets
|
||||||
@ -91,61 +134,23 @@ def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'):
|
|||||||
|
|
||||||
file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
|
file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
|
||||||
if name in assets:
|
if name in assets:
|
||||||
safe_download(file,
|
safe_download(url=f'https://github.com/{repo}/releases/download/{tag}/{name}', file=file, min_bytes=1E5)
|
||||||
url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
|
|
||||||
min_bytes=1E5,
|
|
||||||
error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag}')
|
|
||||||
|
|
||||||
return str(file)
|
return str(file)
|
||||||
|
|
||||||
|
|
||||||
def download(url, dir=Path.cwd(), unzip=True, delete=True, curl=False, threads=1, retry=3):
|
def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3):
|
||||||
# Multithreaded file download and unzip function, used in data.yaml for autodownload
|
# Multithreaded file download and unzip function, used in data.yaml for autodownload
|
||||||
def download_one(url, dir):
|
|
||||||
# Download 1 file
|
|
||||||
success = True
|
|
||||||
if '://' not in str(url) and Path(url).is_file(): # exists ('://' check required in Windows Python<3.10)
|
|
||||||
f = Path(url) # filename
|
|
||||||
else: # does not exist
|
|
||||||
f = dir / Path(url).name
|
|
||||||
LOGGER.info(f'Downloading {url} to {f}...')
|
|
||||||
for i in range(retry + 1):
|
|
||||||
if curl: # curl download with retry, continue
|
|
||||||
s = 'sS' * (threads > 1) # silent
|
|
||||||
r = os.system(f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -')
|
|
||||||
success = r == 0
|
|
||||||
else: # torch download
|
|
||||||
torch.hub.download_url_to_file(url, f, progress=threads == 1)
|
|
||||||
success = f.is_file()
|
|
||||||
if success:
|
|
||||||
break
|
|
||||||
elif i < retry:
|
|
||||||
LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
|
|
||||||
else:
|
|
||||||
LOGGER.warning(f'❌ Failed to download {url}...')
|
|
||||||
|
|
||||||
if unzip and success and f.suffix in ('.zip', '.tar', '.gz'):
|
|
||||||
LOGGER.info(f'Unzipping {f}...')
|
|
||||||
if f.suffix == '.zip':
|
|
||||||
ZipFile(f).extractall(path=dir) # unzip
|
|
||||||
elif f.suffix == '.tar':
|
|
||||||
os.system(f'tar xf {f} --directory {f.parent}') # unzip
|
|
||||||
elif f.suffix == '.gz':
|
|
||||||
os.system(f'tar xfz {f} --directory {f.parent}') # unzip
|
|
||||||
if delete:
|
|
||||||
f.unlink() # remove zip
|
|
||||||
|
|
||||||
dir = Path(dir)
|
dir = Path(dir)
|
||||||
dir.mkdir(parents=True, exist_ok=True) # make directory
|
dir.mkdir(parents=True, exist_ok=True) # make directory
|
||||||
if threads > 1:
|
if threads > 1:
|
||||||
# pool = ThreadPool(threads)
|
|
||||||
# pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
|
|
||||||
# pool.close()
|
|
||||||
# pool.join()
|
|
||||||
with ThreadPool(threads) as pool:
|
with ThreadPool(threads) as pool:
|
||||||
pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
|
pool.map(
|
||||||
|
lambda x: safe_download(
|
||||||
|
url=x[0], dir=x[1], unzip=unzip, delete=delete, curl=curl, retry=retry, progress=threads <= 1),
|
||||||
|
zip(url, repeat(dir)))
|
||||||
pool.close()
|
pool.close()
|
||||||
pool.join()
|
pool.join()
|
||||||
else:
|
else:
|
||||||
for u in [url] if isinstance(url, (str, Path)) else url:
|
for u in [url] if isinstance(url, (str, Path)) else url:
|
||||||
download_one(u, dir)
|
safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry)
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import math
|
import math
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from urllib.error import URLError
|
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
@ -12,9 +11,9 @@ import pandas as pd
|
|||||||
import torch
|
import torch
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
|
||||||
from ultralytics.yolo.utils import FONT, USER_CONFIG_DIR, threaded
|
from ultralytics.yolo.utils import threaded
|
||||||
|
|
||||||
from .checks import check_font, check_requirements, is_ascii
|
from .checks import check_font, is_ascii
|
||||||
from .files import increment_path
|
from .files import increment_path
|
||||||
from .ops import clip_coords, scale_image, xywh2xyxy, xyxy2xywh
|
from .ops import clip_coords, scale_image, xywh2xyxy, xyxy2xywh
|
||||||
|
|
||||||
@ -49,14 +48,20 @@ class Annotator:
|
|||||||
if self.pil: # use PIL
|
if self.pil: # use PIL
|
||||||
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
|
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
|
||||||
self.draw = ImageDraw.Draw(self.im)
|
self.draw = ImageDraw.Draw(self.im)
|
||||||
self.font = check_pil_font(font='Arial.Unicode.ttf' if non_ascii else font,
|
try:
|
||||||
size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12))
|
font = check_font('Arial.Unicode.ttf' if non_ascii else font)
|
||||||
|
size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)
|
||||||
|
self.font = ImageFont.truetype(str(font), size)
|
||||||
|
except Exception:
|
||||||
|
self.font = ImageFont.load_default()
|
||||||
else: # use cv2
|
else: # use cv2
|
||||||
self.im = im
|
self.im = im
|
||||||
self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) # line width
|
self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) # line width
|
||||||
|
|
||||||
def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
|
def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
|
||||||
# Add one xyxy box to image with label
|
# Add one xyxy box to image with label
|
||||||
|
if isinstance(box, torch.Tensor):
|
||||||
|
box = box.tolist()
|
||||||
if self.pil or not is_ascii(label):
|
if self.pil or not is_ascii(label):
|
||||||
self.draw.rectangle(box, width=self.lw, outline=color) # box
|
self.draw.rectangle(box, width=self.lw, outline=color) # box
|
||||||
if label:
|
if label:
|
||||||
@ -139,22 +144,6 @@ class Annotator:
|
|||||||
return np.asarray(self.im)
|
return np.asarray(self.im)
|
||||||
|
|
||||||
|
|
||||||
def check_pil_font(font=FONT, size=10):
|
|
||||||
# Return a PIL TrueType Font, downloading to CONFIG_DIR if necessary
|
|
||||||
font = Path(font)
|
|
||||||
font = font if font.exists() else (USER_CONFIG_DIR / font.name)
|
|
||||||
try:
|
|
||||||
return ImageFont.truetype(str(font) if font.exists() else font.name, size)
|
|
||||||
except Exception: # download if missing
|
|
||||||
try:
|
|
||||||
check_font(font)
|
|
||||||
return ImageFont.truetype(str(font), size)
|
|
||||||
except TypeError:
|
|
||||||
check_requirements('Pillow>=8.4.0') # known issue https://github.com/ultralytics/yolov5/issues/5374
|
|
||||||
except URLError: # not online
|
|
||||||
return ImageFont.load_default()
|
|
||||||
|
|
||||||
|
|
||||||
def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
|
def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
|
||||||
# Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
|
# Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
|
||||||
xyxy = torch.tensor(xyxy).view(-1, 4)
|
xyxy = torch.tensor(xyxy).view(-1, 4)
|
||||||
|
@ -85,8 +85,8 @@ def select_device(device='', batch=0, newline=False):
|
|||||||
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
|
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
|
||||||
n = len(devices) # device count
|
n = len(devices) # device count
|
||||||
if n > 1 and batch > 0 and batch % n != 0: # check batch_size is divisible by device_count
|
if n > 1 and batch > 0 and batch % n != 0: # check batch_size is divisible by device_count
|
||||||
raise ValueError(f'batch={batch} is not multiple of GPU count {n}.\n'
|
raise ValueError(f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
|
||||||
f'Try batch={batch // n} or batch={batch // n + 1}')
|
f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}.")
|
||||||
space = ' ' * (len(s) + 1)
|
space = ' ' * (len(s) + 1)
|
||||||
for i, d in enumerate(devices):
|
for i, d in enumerate(devices):
|
||||||
p = torch.cuda.get_device_properties(i)
|
p = torch.cuda.get_device_properties(i)
|
||||||
|
@ -74,7 +74,7 @@ def predict(cfg=DEFAULT_CFG, use_python=False):
|
|||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
YOLO(model)(**args)
|
YOLO(model)(**args)
|
||||||
else:
|
else:
|
||||||
predictor = ClassificationPredictor(args)
|
predictor = ClassificationPredictor(overrides=args)
|
||||||
predictor.predict_cli()
|
predictor.predict_cli()
|
||||||
|
|
||||||
|
|
||||||
|
@ -146,7 +146,7 @@ def train(cfg=DEFAULT_CFG, use_python=False):
|
|||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
YOLO(model).train(**args)
|
YOLO(model).train(**args)
|
||||||
else:
|
else:
|
||||||
trainer = ClassificationTrainer(args)
|
trainer = ClassificationTrainer(overrides=args)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
|
|
||||||
|
@ -92,7 +92,7 @@ def predict(cfg=DEFAULT_CFG, use_python=False):
|
|||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
YOLO(model)(**args)
|
YOLO(model)(**args)
|
||||||
else:
|
else:
|
||||||
predictor = DetectionPredictor(args)
|
predictor = DetectionPredictor(overrides=args)
|
||||||
predictor.predict_cli()
|
predictor.predict_cli()
|
||||||
|
|
||||||
|
|
||||||
|
@ -204,7 +204,7 @@ def train(cfg=DEFAULT_CFG, use_python=False):
|
|||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
YOLO(model).train(**args)
|
YOLO(model).train(**args)
|
||||||
else:
|
else:
|
||||||
trainer = DetectionTrainer(args)
|
trainer = DetectionTrainer(overrides=args)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
|
|
||||||
|
@ -110,7 +110,7 @@ def predict(cfg=DEFAULT_CFG, use_python=False):
|
|||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
YOLO(model)(**args)
|
YOLO(model)(**args)
|
||||||
else:
|
else:
|
||||||
predictor = SegmentationPredictor(args)
|
predictor = SegmentationPredictor(overrides=args)
|
||||||
predictor.predict_cli()
|
predictor.predict_cli()
|
||||||
|
|
||||||
|
|
||||||
|
@ -150,7 +150,7 @@ def train(cfg=DEFAULT_CFG, use_python=False):
|
|||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
YOLO(model).train(**args)
|
YOLO(model).train(**args)
|
||||||
else:
|
else:
|
||||||
trainer = SegmentationTrainer(args)
|
trainer = SegmentationTrainer(overrides=args)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user