mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 05:24:22 +08:00
imgsz
warning fix, download function consolidation (#681)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: HaeJin Lee <seareale@gmail.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
parent
0609561549
commit
899abe9f82
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
__version__ = "8.0.22"
|
||||
__version__ = "8.0.23"
|
||||
|
||||
from ultralytics.yolo.engine.model import YOLO
|
||||
from ultralytics.yolo.utils import ops
|
||||
|
@ -14,7 +14,7 @@ from PIL import Image
|
||||
|
||||
from ultralytics.yolo.utils import LOGGER, ROOT, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_version
|
||||
from ultralytics.yolo.utils.downloads import attempt_download, is_url
|
||||
from ultralytics.yolo.utils.downloads import attempt_download_asset, is_url
|
||||
from ultralytics.yolo.utils.ops import xywh2xyxy
|
||||
|
||||
|
||||
@ -58,7 +58,7 @@ class AutoBackend(nn.Module):
|
||||
model = None # TODO: resolves ONNX inference, verify effect on other backends
|
||||
cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA
|
||||
if not (pt or triton or nn_module):
|
||||
w = attempt_download(w) # download if not local
|
||||
w = attempt_download_asset(w) # download if not local
|
||||
|
||||
# NOTE: special case: in-memory pytorch model
|
||||
if nn_module:
|
||||
|
@ -325,9 +325,9 @@ def torch_safe_load(weight):
|
||||
Returns:
|
||||
The loaded PyTorch model.
|
||||
"""
|
||||
from ultralytics.yolo.utils.downloads import attempt_download
|
||||
from ultralytics.yolo.utils.downloads import attempt_download_asset
|
||||
|
||||
file = attempt_download(weight) # search online if missing locally
|
||||
file = attempt_download_asset(weight) # search online if missing locally
|
||||
try:
|
||||
return torch.load(file, map_location='cpu') # load
|
||||
except ModuleNotFoundError as e:
|
||||
|
@ -90,7 +90,7 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG, override
|
||||
|
||||
# Type checks
|
||||
for k in 'project', 'name':
|
||||
if isinstance(cfg[k], (int, float)):
|
||||
if k in cfg and isinstance(cfg[k], (int, float)):
|
||||
cfg[k] = str(cfg[k])
|
||||
|
||||
# Return instance
|
||||
@ -176,7 +176,7 @@ def entrypoint(debug=False):
|
||||
'version': lambda: LOGGER.info(__version__),
|
||||
'settings': lambda: yaml_print(USER_CONFIG_DIR / 'settings.yaml'),
|
||||
'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
|
||||
'copy-cfg': copy_default_config}
|
||||
'copy-cfg': copy_default_cfg}
|
||||
|
||||
overrides = {} # basic overrides, i.e. imgsz=320
|
||||
for a in merge_equals_args(args): # merge spaces around '=' sign
|
||||
@ -221,7 +221,7 @@ def entrypoint(debug=False):
|
||||
task2data = dict(detect='coco128.yaml', segment='coco128-seg.yaml', classify='mnist160')
|
||||
|
||||
# Mode
|
||||
mode = overrides['mode']
|
||||
mode = overrides.get('mode', None)
|
||||
if mode is None:
|
||||
mode = DEFAULT_CFG.mode or 'predict'
|
||||
LOGGER.warning(f"WARNING ⚠️ 'mode=' is missing. Valid modes are {modes}. Using default 'mode={mode}'.")
|
||||
@ -266,7 +266,7 @@ def entrypoint(debug=False):
|
||||
|
||||
|
||||
# Special modes --------------------------------------------------------------------------------------------------------
|
||||
def copy_default_config():
|
||||
def copy_default_cfg():
|
||||
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
|
||||
shutil.copy2(DEFAULT_CFG_PATH, new_file)
|
||||
LOGGER.info(f"{PREFIX}{DEFAULT_CFG_PATH} copied to {new_file}\n"
|
||||
|
@ -44,8 +44,20 @@ class Compose:
|
||||
self.transforms = transforms
|
||||
|
||||
def __call__(self, data):
|
||||
mosaic_p = None
|
||||
mosaic_imgsz = None
|
||||
|
||||
for t in self.transforms:
|
||||
data = t(data)
|
||||
if isinstance(t, Mosaic):
|
||||
temp = t(data)
|
||||
mosaic_p = False if temp == data else True
|
||||
mosaic_imgsz = t.imgsz
|
||||
data = temp
|
||||
else:
|
||||
if isinstance(t, RandomPerspective):
|
||||
t.border = [-mosaic_imgsz // 2, -mosaic_imgsz // 2] if mosaic_p else [0, 0]
|
||||
data = t(data)
|
||||
|
||||
return data
|
||||
|
||||
def append(self, transform):
|
||||
|
@ -120,7 +120,8 @@ class BaseDataset(Dataset):
|
||||
im = np.load(fn)
|
||||
else: # read image
|
||||
im = cv2.imread(f) # BGR
|
||||
assert im is not None, f"Image Not Found {f}"
|
||||
if im is None:
|
||||
raise FileNotFoundError(f"Image Not Found {f}")
|
||||
h0, w0 = im.shape[:2] # orig hw
|
||||
r = self.imgsz / max(h0, w0) # ratio
|
||||
if r != 1: # if sizes are not equal
|
||||
|
@ -65,7 +65,7 @@ def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank
|
||||
assert mode in ["train", "val"]
|
||||
shuffle = mode == "train"
|
||||
if cfg.rect and shuffle:
|
||||
LOGGER.warning("WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False")
|
||||
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
|
||||
shuffle = False
|
||||
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
||||
dataset = YOLODataset(
|
||||
|
@ -64,7 +64,7 @@ download: |
|
||||
# Download
|
||||
dir = Path(yaml['path']) # dataset root dir
|
||||
urls = ['https://argoverse-hd.s3.us-east-2.amazonaws.com/Argoverse-HD-Full.zip']
|
||||
download(urls, dir=dir, delete=False)
|
||||
download(urls, dir=dir)
|
||||
|
||||
# Convert
|
||||
annotations_dir = 'Argoverse-HD/annotations/'
|
||||
|
@ -411,12 +411,12 @@ download: |
|
||||
# Download
|
||||
url = f"https://dorc.ks3-cn-beijing.ksyun.com/data-set/2020Objects365%E6%95%B0%E6%8D%AE%E9%9B%86/{split}/"
|
||||
if split == 'train':
|
||||
download([f'{url}zhiyuan_objv2_{split}.tar.gz'], dir=dir, delete=False) # annotations json
|
||||
download([f'{url}patch{i}.tar.gz' for i in range(patches)], dir=images, curl=True, delete=False, threads=8)
|
||||
download([f'{url}zhiyuan_objv2_{split}.tar.gz'], dir=dir) # annotations json
|
||||
download([f'{url}patch{i}.tar.gz' for i in range(patches)], dir=images, curl=True, threads=8)
|
||||
elif split == 'val':
|
||||
download([f'{url}zhiyuan_objv2_{split}.json'], dir=dir, delete=False) # annotations json
|
||||
download([f'{url}images/v1/patch{i}.tar.gz' for i in range(15 + 1)], dir=images, curl=True, delete=False, threads=8)
|
||||
download([f'{url}images/v2/patch{i}.tar.gz' for i in range(16, patches)], dir=images, curl=True, delete=False, threads=8)
|
||||
download([f'{url}zhiyuan_objv2_{split}.json'], dir=dir) # annotations json
|
||||
download([f'{url}images/v1/patch{i}.tar.gz' for i in range(15 + 1)], dir=images, curl=True, threads=8)
|
||||
download([f'{url}images/v2/patch{i}.tar.gz' for i in range(16, patches)], dir=images, curl=True, threads=8)
|
||||
|
||||
# Move
|
||||
for f in tqdm(images.rglob('*.jpg'), desc=f'Moving {split} images'):
|
||||
|
@ -34,7 +34,7 @@ download: |
|
||||
dir = Path(yaml['path']) # dataset root dir
|
||||
parent = Path(dir.parent) # download dir
|
||||
urls = ['http://trax-geometry.s3.amazonaws.com/cvpr_challenge/SKU110K_fixed.tar.gz']
|
||||
download(urls, dir=parent, delete=False)
|
||||
download(urls, dir=parent)
|
||||
|
||||
# Rename directories
|
||||
if dir.exists():
|
||||
|
@ -81,7 +81,7 @@ download: |
|
||||
urls = [f'{url}VOCtrainval_06-Nov-2007.zip', # 446MB, 5012 images
|
||||
f'{url}VOCtest_06-Nov-2007.zip', # 438MB, 4953 images
|
||||
f'{url}VOCtrainval_11-May-2012.zip'] # 1.95GB, 17126 images
|
||||
download(urls, dir=dir / 'images', delete=False, curl=True, threads=3)
|
||||
download(urls, dir=dir / 'images', curl=True, threads=3)
|
||||
|
||||
# Convert
|
||||
path = dir / 'images/VOCdevkit'
|
||||
|
@ -138,7 +138,7 @@ download: |
|
||||
# urls = ['https://d307kc0mrhucc3.cloudfront.net/train_labels.zip', # train labels
|
||||
# 'https://d307kc0mrhucc3.cloudfront.net/train_images.zip', # 15G, 847 train images
|
||||
# 'https://d307kc0mrhucc3.cloudfront.net/val_images.zip'] # 5G, 282 val images (no labels)
|
||||
# download(urls, dir=dir, delete=False)
|
||||
# download(urls, dir=dir)
|
||||
|
||||
# Convert labels
|
||||
convert_labels(dir / 'xView_train.geojson')
|
||||
|
@ -237,11 +237,7 @@ def check_det_dataset(dataset, autodownload=True):
|
||||
raise FileNotFoundError(msg)
|
||||
t = time.time()
|
||||
if s.startswith('http') and s.endswith('.zip'): # URL
|
||||
f = Path(s).name # filename
|
||||
safe_download(file=f, url=s)
|
||||
Path(DATASETS_DIR).mkdir(parents=True, exist_ok=True) # create root
|
||||
unzip_file(f, path=DATASETS_DIR) # unzip
|
||||
Path(f).unlink() # remove zip
|
||||
safe_download(url=s, dir=DATASETS_DIR, delete=True)
|
||||
r = None # success
|
||||
elif s.startswith('bash '): # bash script
|
||||
LOGGER.info(f'Running {s} ...')
|
||||
@ -251,7 +247,7 @@ def check_det_dataset(dataset, autodownload=True):
|
||||
dt = f'({round(time.time() - t, 1)}s)'
|
||||
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
|
||||
LOGGER.info(f"Dataset download {s}")
|
||||
check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True) # download fonts
|
||||
check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf') # download fonts
|
||||
|
||||
return data # dictionary
|
||||
|
||||
|
@ -125,7 +125,7 @@ class Exporter:
|
||||
Initializes the Exporter class.
|
||||
|
||||
Args:
|
||||
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
||||
cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
|
||||
overrides (dict, optional): Configuration overrides. Defaults to None.
|
||||
"""
|
||||
self.args = get_cfg(cfg, overrides)
|
||||
|
@ -28,7 +28,6 @@ DEFAULT_CFG_PATH = ROOT / "yolo/cfg/default.yaml"
|
||||
RANK = int(os.getenv('RANK', -1))
|
||||
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
|
||||
AUTOINSTALL = str(os.getenv('YOLO_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
|
||||
FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf
|
||||
VERBOSE = str(os.getenv('YOLO_VERBOSE', True)).lower() == 'true' # global verbose mode
|
||||
TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' # tqdm bar format
|
||||
LOGGING_NAME = 'ultralytics'
|
||||
@ -328,6 +327,20 @@ def get_git_origin_url():
|
||||
return None # if not git dir or on error
|
||||
|
||||
|
||||
def get_git_branch():
|
||||
"""
|
||||
Returns the current git branch name. If not in a git repository, returns None.
|
||||
|
||||
Returns:
|
||||
(str) or (None): The current git branch name.
|
||||
"""
|
||||
if is_git_dir():
|
||||
with contextlib.suppress(subprocess.CalledProcessError):
|
||||
origin = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
|
||||
return origin.decode().strip()
|
||||
return None # if not git dir or on error
|
||||
|
||||
|
||||
def get_default_args(func):
|
||||
# Get func() default arguments
|
||||
signature = inspect.signature(func)
|
||||
@ -466,7 +479,8 @@ def set_sentry():
|
||||
if SETTINGS['sync'] and \
|
||||
not is_pytest_running() and \
|
||||
not is_github_actions_ci() and \
|
||||
(is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git"):
|
||||
(is_pip_package() or
|
||||
(get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git" and get_git_branch() == "main")):
|
||||
import sentry_sdk # noqa
|
||||
|
||||
import ultralytics
|
||||
|
@ -28,7 +28,7 @@ def autobatch(model, imgsz=640, fraction=0.7, batch_size=16):
|
||||
|
||||
# Check device
|
||||
prefix = colorstr('AutoBatch: ')
|
||||
LOGGER.info(f'{prefix}Computing optimal batch size for --imgsz {imgsz}')
|
||||
LOGGER.info(f'{prefix}Computing optimal batch size for imgsz={imgsz}')
|
||||
device = next(model.parameters()).device # get model device
|
||||
if device.type == 'cpu':
|
||||
LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
|
||||
|
@ -17,9 +17,10 @@ import pkg_resources as pkg
|
||||
import psutil
|
||||
import torch
|
||||
from IPython import display
|
||||
from matplotlib import font_manager
|
||||
|
||||
from ultralytics.yolo.utils import (AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, downloads,
|
||||
emojis, is_colab, is_docker, is_jupyter)
|
||||
from ultralytics.yolo.utils import (AUTOINSTALL, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, downloads, emojis,
|
||||
is_colab, is_docker, is_jupyter)
|
||||
|
||||
|
||||
def is_ascii(s) -> bool:
|
||||
@ -57,15 +58,14 @@ def check_imgsz(imgsz, stride=32, min_dim=1, floor=0):
|
||||
stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)
|
||||
|
||||
# Convert image size to list if it is an integer
|
||||
if isinstance(imgsz, int):
|
||||
imgsz = [imgsz]
|
||||
imgsz = [imgsz] if isinstance(imgsz, int) else list(imgsz)
|
||||
|
||||
# Make image size a multiple of the stride
|
||||
sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]
|
||||
|
||||
# Print warning message if image size was updated
|
||||
if sz != imgsz:
|
||||
LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {stride}, updating to {sz}')
|
||||
LOGGER.warning(f'WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}')
|
||||
|
||||
# Add missing dimensions if necessary
|
||||
sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz
|
||||
@ -104,26 +104,33 @@ def check_version(current: str = "0.0.0",
|
||||
return result
|
||||
|
||||
|
||||
def check_font(font: str = FONT, progress: bool = False) -> None:
|
||||
def check_font(font='Arial.ttf'):
|
||||
"""
|
||||
Download font file to the user's configuration directory if it does not already exist.
|
||||
Find font locally or download to user's configuration directory if it does not already exist.
|
||||
|
||||
Args:
|
||||
font (str): Path to font file.
|
||||
progress (bool): If True, display a progress bar during the download.
|
||||
font (str): Path or name of font.
|
||||
|
||||
Returns:
|
||||
None
|
||||
file (Path): Resolved font file path.
|
||||
"""
|
||||
font = Path(font)
|
||||
name = Path(font).name
|
||||
|
||||
# Destination path for the font file
|
||||
file = USER_CONFIG_DIR / font.name
|
||||
# Check USER_CONFIG_DIR
|
||||
file = USER_CONFIG_DIR / name
|
||||
if file.exists():
|
||||
return file
|
||||
|
||||
# Check if font file exists at the source or destination path
|
||||
if not font.exists() and not file.exists():
|
||||
# Download font file
|
||||
downloads.safe_download(file=file, url=f'https://ultralytics.com/assets/{font.name}', progress=progress)
|
||||
# Check system fonts
|
||||
matches = [s for s in font_manager.findSystemFonts() if font in s]
|
||||
if any(matches):
|
||||
return matches[0]
|
||||
|
||||
# Download to USER_CONFIG_DIR if missing
|
||||
url = f'https://ultralytics.com/assets/{name}'
|
||||
if downloads.is_url(url):
|
||||
downloads.safe_download(url=url, file=file)
|
||||
return file
|
||||
|
||||
|
||||
def check_online() -> bool:
|
||||
@ -213,7 +220,7 @@ def check_file(file, suffix=''):
|
||||
if Path(file).is_file():
|
||||
LOGGER.info(f'Found {url} locally at {file}') # file already exists
|
||||
else:
|
||||
downloads.safe_download(file=file, url=url)
|
||||
downloads.safe_download(url=url, file=file)
|
||||
return file
|
||||
else: # search
|
||||
files = []
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
import logging
|
||||
import contextlib
|
||||
import os
|
||||
import subprocess
|
||||
import urllib
|
||||
@ -15,27 +15,6 @@ import torch
|
||||
from ultralytics.yolo.utils import LOGGER
|
||||
|
||||
|
||||
def safe_download(file, url, url2=None, min_bytes=1E0, error_msg='', progress=True):
|
||||
# Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
|
||||
file = Path(file)
|
||||
assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
|
||||
try: # url1
|
||||
LOGGER.info(f'Downloading {url} to {file}...')
|
||||
torch.hub.download_url_to_file(url, str(file), progress=progress and LOGGER.level <= logging.INFO)
|
||||
assert file.exists() and file.stat().st_size > min_bytes, assert_msg # check
|
||||
except Exception as e: # url2
|
||||
if file.exists():
|
||||
file.unlink() # remove partial downloads
|
||||
LOGGER.info(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...')
|
||||
os.system(f"curl -# -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
|
||||
finally:
|
||||
if not file.exists() or file.stat().st_size < min_bytes: # check
|
||||
if file.exists():
|
||||
file.unlink() # remove partial downloads
|
||||
LOGGER.warning(f"ERROR: {assert_msg}\n{error_msg}")
|
||||
LOGGER.info('')
|
||||
|
||||
|
||||
def is_url(url, check=True):
|
||||
# Check if string is URL and check if URL exists
|
||||
try:
|
||||
@ -47,7 +26,71 @@ def is_url(url, check=True):
|
||||
return False
|
||||
|
||||
|
||||
def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'):
|
||||
def safe_download(url,
|
||||
file=None,
|
||||
dir=None,
|
||||
unzip=True,
|
||||
delete=False,
|
||||
curl=False,
|
||||
retry=3,
|
||||
min_bytes=1E0,
|
||||
progress=True):
|
||||
"""
|
||||
Function for downloading files from a URL, with options for retrying, unzipping, and deleting the downloaded file.
|
||||
|
||||
Args:
|
||||
url: str: The URL of the file to be downloaded.
|
||||
file: str, optional: The filename of the downloaded file.
|
||||
If not provided, the file will be saved with the same name as the URL.
|
||||
dir: str, optional: The directory to save the downloaded file.
|
||||
If not provided, the file will be saved in the current working directory.
|
||||
unzip: bool, optional: Whether to unzip the downloaded file. Default: True.
|
||||
delete: bool, optional: Whether to delete the downloaded file after unzipping. Default: False.
|
||||
curl: bool, optional: Whether to use curl command line tool for downloading. Default: False.
|
||||
retry: int, optional: The number of times to retry the download in case of failure. Default: 3.
|
||||
min_bytes: float, optional: The minimum number of bytes that the downloaded file should have, to be considered
|
||||
a successful download. Default: 1E0.
|
||||
progress: bool, optional: Whether to display a progress bar during the download. Default: True.
|
||||
"""
|
||||
if '://' not in str(url) and Path(url).is_file(): # exists ('://' check required in Windows Python<3.10)
|
||||
f = Path(url) # filename
|
||||
else: # does not exist
|
||||
assert dir or file, 'dir or file required for download'
|
||||
f = dir / Path(url).name if dir else Path(file)
|
||||
LOGGER.info(f'Downloading {url} to {f}...')
|
||||
f.parent.mkdir(parents=True, exist_ok=True) # make directory if missing
|
||||
for i in range(retry + 1):
|
||||
try:
|
||||
if curl or i > 0: # curl download with retry, continue
|
||||
s = 'sS' * (not progress) # silent
|
||||
r = os.system(f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -')
|
||||
else: # torch download
|
||||
r = torch.hub.download_url_to_file(url, f, progress=progress)
|
||||
assert r in {0, None}
|
||||
except Exception as e:
|
||||
if i >= retry:
|
||||
raise ConnectionError(f'❌ Download failure for {url}') from e
|
||||
LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
|
||||
continue
|
||||
|
||||
if f.exists():
|
||||
if f.stat().st_size > min_bytes:
|
||||
break # success
|
||||
f.unlink() # remove partial downloads
|
||||
|
||||
if unzip and f.exists() and f.suffix in {'.zip', '.tar', '.gz'}:
|
||||
LOGGER.info(f'Unzipping {f}...')
|
||||
if f.suffix == '.zip':
|
||||
ZipFile(f).extractall(path=f.parent) # unzip
|
||||
elif f.suffix == '.tar':
|
||||
os.system(f'tar xf {f} --directory {f.parent}') # unzip
|
||||
elif f.suffix == '.gz':
|
||||
os.system(f'tar xfz {f} --directory {f.parent}') # unzip
|
||||
if delete:
|
||||
f.unlink() # remove zip
|
||||
|
||||
|
||||
def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
|
||||
# Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
|
||||
from ultralytics.yolo.utils import SETTINGS
|
||||
|
||||
@ -73,7 +116,7 @@ def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'):
|
||||
if Path(file).is_file():
|
||||
LOGGER.info(f'Found {url} locally at {file}') # file already exists
|
||||
else:
|
||||
safe_download(file=file, url=url, min_bytes=1E5)
|
||||
safe_download(url=url, file=file, min_bytes=1E5)
|
||||
return file
|
||||
|
||||
# GitHub assets
|
||||
@ -91,61 +134,23 @@ def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'):
|
||||
|
||||
file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
|
||||
if name in assets:
|
||||
safe_download(file,
|
||||
url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
|
||||
min_bytes=1E5,
|
||||
error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag}')
|
||||
safe_download(url=f'https://github.com/{repo}/releases/download/{tag}/{name}', file=file, min_bytes=1E5)
|
||||
|
||||
return str(file)
|
||||
|
||||
|
||||
def download(url, dir=Path.cwd(), unzip=True, delete=True, curl=False, threads=1, retry=3):
|
||||
def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3):
|
||||
# Multithreaded file download and unzip function, used in data.yaml for autodownload
|
||||
def download_one(url, dir):
|
||||
# Download 1 file
|
||||
success = True
|
||||
if '://' not in str(url) and Path(url).is_file(): # exists ('://' check required in Windows Python<3.10)
|
||||
f = Path(url) # filename
|
||||
else: # does not exist
|
||||
f = dir / Path(url).name
|
||||
LOGGER.info(f'Downloading {url} to {f}...')
|
||||
for i in range(retry + 1):
|
||||
if curl: # curl download with retry, continue
|
||||
s = 'sS' * (threads > 1) # silent
|
||||
r = os.system(f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -')
|
||||
success = r == 0
|
||||
else: # torch download
|
||||
torch.hub.download_url_to_file(url, f, progress=threads == 1)
|
||||
success = f.is_file()
|
||||
if success:
|
||||
break
|
||||
elif i < retry:
|
||||
LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
|
||||
else:
|
||||
LOGGER.warning(f'❌ Failed to download {url}...')
|
||||
|
||||
if unzip and success and f.suffix in ('.zip', '.tar', '.gz'):
|
||||
LOGGER.info(f'Unzipping {f}...')
|
||||
if f.suffix == '.zip':
|
||||
ZipFile(f).extractall(path=dir) # unzip
|
||||
elif f.suffix == '.tar':
|
||||
os.system(f'tar xf {f} --directory {f.parent}') # unzip
|
||||
elif f.suffix == '.gz':
|
||||
os.system(f'tar xfz {f} --directory {f.parent}') # unzip
|
||||
if delete:
|
||||
f.unlink() # remove zip
|
||||
|
||||
dir = Path(dir)
|
||||
dir.mkdir(parents=True, exist_ok=True) # make directory
|
||||
if threads > 1:
|
||||
# pool = ThreadPool(threads)
|
||||
# pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
|
||||
# pool.close()
|
||||
# pool.join()
|
||||
with ThreadPool(threads) as pool:
|
||||
pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
|
||||
pool.map(
|
||||
lambda x: safe_download(
|
||||
url=x[0], dir=x[1], unzip=unzip, delete=delete, curl=curl, retry=retry, progress=threads <= 1),
|
||||
zip(url, repeat(dir)))
|
||||
pool.close()
|
||||
pool.join()
|
||||
else:
|
||||
for u in [url] if isinstance(url, (str, Path)) else url:
|
||||
download_one(u, dir)
|
||||
safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry)
|
||||
|
@ -3,7 +3,6 @@
|
||||
import contextlib
|
||||
import math
|
||||
from pathlib import Path
|
||||
from urllib.error import URLError
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
@ -12,9 +11,9 @@ import pandas as pd
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
from ultralytics.yolo.utils import FONT, USER_CONFIG_DIR, threaded
|
||||
from ultralytics.yolo.utils import threaded
|
||||
|
||||
from .checks import check_font, check_requirements, is_ascii
|
||||
from .checks import check_font, is_ascii
|
||||
from .files import increment_path
|
||||
from .ops import clip_coords, scale_image, xywh2xyxy, xyxy2xywh
|
||||
|
||||
@ -49,14 +48,20 @@ class Annotator:
|
||||
if self.pil: # use PIL
|
||||
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
|
||||
self.draw = ImageDraw.Draw(self.im)
|
||||
self.font = check_pil_font(font='Arial.Unicode.ttf' if non_ascii else font,
|
||||
size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12))
|
||||
try:
|
||||
font = check_font('Arial.Unicode.ttf' if non_ascii else font)
|
||||
size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)
|
||||
self.font = ImageFont.truetype(str(font), size)
|
||||
except Exception:
|
||||
self.font = ImageFont.load_default()
|
||||
else: # use cv2
|
||||
self.im = im
|
||||
self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) # line width
|
||||
|
||||
def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
|
||||
# Add one xyxy box to image with label
|
||||
if isinstance(box, torch.Tensor):
|
||||
box = box.tolist()
|
||||
if self.pil or not is_ascii(label):
|
||||
self.draw.rectangle(box, width=self.lw, outline=color) # box
|
||||
if label:
|
||||
@ -139,22 +144,6 @@ class Annotator:
|
||||
return np.asarray(self.im)
|
||||
|
||||
|
||||
def check_pil_font(font=FONT, size=10):
|
||||
# Return a PIL TrueType Font, downloading to CONFIG_DIR if necessary
|
||||
font = Path(font)
|
||||
font = font if font.exists() else (USER_CONFIG_DIR / font.name)
|
||||
try:
|
||||
return ImageFont.truetype(str(font) if font.exists() else font.name, size)
|
||||
except Exception: # download if missing
|
||||
try:
|
||||
check_font(font)
|
||||
return ImageFont.truetype(str(font), size)
|
||||
except TypeError:
|
||||
check_requirements('Pillow>=8.4.0') # known issue https://github.com/ultralytics/yolov5/issues/5374
|
||||
except URLError: # not online
|
||||
return ImageFont.load_default()
|
||||
|
||||
|
||||
def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
|
||||
# Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
|
||||
xyxy = torch.tensor(xyxy).view(-1, 4)
|
||||
|
@ -85,8 +85,8 @@ def select_device(device='', batch=0, newline=False):
|
||||
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
|
||||
n = len(devices) # device count
|
||||
if n > 1 and batch > 0 and batch % n != 0: # check batch_size is divisible by device_count
|
||||
raise ValueError(f'batch={batch} is not multiple of GPU count {n}.\n'
|
||||
f'Try batch={batch // n} or batch={batch // n + 1}')
|
||||
raise ValueError(f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
|
||||
f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}.")
|
||||
space = ' ' * (len(s) + 1)
|
||||
for i, d in enumerate(devices):
|
||||
p = torch.cuda.get_device_properties(i)
|
||||
|
@ -74,7 +74,7 @@ def predict(cfg=DEFAULT_CFG, use_python=False):
|
||||
from ultralytics import YOLO
|
||||
YOLO(model)(**args)
|
||||
else:
|
||||
predictor = ClassificationPredictor(args)
|
||||
predictor = ClassificationPredictor(overrides=args)
|
||||
predictor.predict_cli()
|
||||
|
||||
|
||||
|
@ -146,7 +146,7 @@ def train(cfg=DEFAULT_CFG, use_python=False):
|
||||
from ultralytics import YOLO
|
||||
YOLO(model).train(**args)
|
||||
else:
|
||||
trainer = ClassificationTrainer(args)
|
||||
trainer = ClassificationTrainer(overrides=args)
|
||||
trainer.train()
|
||||
|
||||
|
||||
|
@ -92,7 +92,7 @@ def predict(cfg=DEFAULT_CFG, use_python=False):
|
||||
from ultralytics import YOLO
|
||||
YOLO(model)(**args)
|
||||
else:
|
||||
predictor = DetectionPredictor(args)
|
||||
predictor = DetectionPredictor(overrides=args)
|
||||
predictor.predict_cli()
|
||||
|
||||
|
||||
|
@ -204,7 +204,7 @@ def train(cfg=DEFAULT_CFG, use_python=False):
|
||||
from ultralytics import YOLO
|
||||
YOLO(model).train(**args)
|
||||
else:
|
||||
trainer = DetectionTrainer(args)
|
||||
trainer = DetectionTrainer(overrides=args)
|
||||
trainer.train()
|
||||
|
||||
|
||||
|
@ -110,7 +110,7 @@ def predict(cfg=DEFAULT_CFG, use_python=False):
|
||||
from ultralytics import YOLO
|
||||
YOLO(model)(**args)
|
||||
else:
|
||||
predictor = SegmentationPredictor(args)
|
||||
predictor = SegmentationPredictor(overrides=args)
|
||||
predictor.predict_cli()
|
||||
|
||||
|
||||
|
@ -150,7 +150,7 @@ def train(cfg=DEFAULT_CFG, use_python=False):
|
||||
from ultralytics import YOLO
|
||||
YOLO(model).train(**args)
|
||||
else:
|
||||
trainer = SegmentationTrainer(args)
|
||||
trainer = SegmentationTrainer(overrides=args)
|
||||
trainer.train()
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user