mirror of https://github.com/THU-MIG/yolov10.git (synced 2025-05-23 05:24:22 +08:00)

ultralytics 8.0.44 export and task fixes (#1088)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>

This commit is contained in:
parent fe61018975
commit 3ea659411b

9 .github/workflows/ci.yaml (vendored)
@@ -32,9 +32,14 @@ jobs:
# key: ${{ runner.os }}-Benchmarks-${{ hashFiles('requirements.txt') }}
# restore-keys: ${{ runner.os }}-Benchmarks-
- name: Install requirements
shell: bash # for Windows compatibility
run: |
python -m pip install --upgrade pip wheel
pip install -e . coremltools openvino-dev tensorflow-cpu paddlepaddle x2paddle --extra-index-url https://download.pytorch.org/whl/cpu
if [ "${{ matrix.os }}" == "macos-latest" ]; then
pip install -e . coremltools openvino-dev tensorflow-macos --extra-index-url https://download.pytorch.org/whl/cpu
else
pip install -e . coremltools openvino-dev tensorflow-cpu paddlepaddle x2paddle --extra-index-url https://download.pytorch.org/whl/cpu
fi
yolo export format=tflite
- name: Check environment
run: |
@@ -94,6 +99,7 @@ jobs:
key: ${{ runner.os }}-${{ matrix.python-version }}-pip-${{ hashFiles('requirements.txt') }}
restore-keys: ${{ runner.os }}-${{ matrix.python-version }}-pip-
- name: Install requirements
shell: bash # for Windows compatibility
run: |
python -m pip install --upgrade pip wheel
if [ "${{ matrix.torch }}" == "1.8.0" ]; then
@@ -101,7 +107,6 @@ jobs:
else
pip install -e '.[export]' pytest --extra-index-url https://download.pytorch.org/whl/cpu
fi
shell: bash # for Windows compatibility
- name: Check environment
run: |
echo "RUNNER_OS is ${{ runner.os }}"

@@ -78,13 +78,6 @@
}
]
},
{
"cell_type": "markdown",
"source": [],
"metadata": {
"id": "ZOwTlorPd8-D"
}
},
{
"cell_type": "markdown",
"metadata": {

@@ -3,7 +3,7 @@
import subprocess
from pathlib import Path

from ultralytics.yolo.utils import LINUX, ROOT, SETTINGS
from ultralytics.yolo.utils import LINUX, ROOT, SETTINGS, checks

MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n'
CFG = 'yolov8n'
@@ -49,9 +49,10 @@ def test_val_classify():
# Predict checks -------------------------------------------------------------------------------------------------------
def test_predict_detect():
run(f"yolo predict model={MODEL}.pt source={ROOT / 'assets'} imgsz=32")
run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32')
run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_landscape_min.mov imgsz=32')
run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_portrait_min.mov imgsz=32')
if checks.check_online():
run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32')
run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_landscape_min.mov imgsz=32')
run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_portrait_min.mov imgsz=32')


def test_predict_segment():

@@ -9,7 +9,7 @@ from PIL import Image

from ultralytics import YOLO
from ultralytics.yolo.data.build import load_inference_source
from ultralytics.yolo.utils import LINUX, ROOT, SETTINGS
from ultralytics.yolo.utils import LINUX, ROOT, SETTINGS, checks

MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt'
CFG = 'yolov8n.yaml'
@@ -49,28 +49,20 @@ def test_predict_dir():

def test_predict_img():
model = YOLO(MODEL)
output = model(source=Image.open(SOURCE), save=True, verbose=True) # PIL
assert len(output) == 1, 'predict test failed'
img = cv2.imread(str(SOURCE))
output = model(source=img, save=True, save_txt=True) # ndarray
assert len(output) == 1, 'predict test failed'
output = model(source=[img, img], save=True, save_txt=True) # batch
assert len(output) == 2, 'predict test failed'
output = model(source=[img, img], save=True, stream=True) # stream
assert len(list(output)) == 2, 'predict test failed'
tens = torch.zeros(320, 640, 3)
output = model(tens.numpy())
assert len(output) == 1, 'predict test failed'
# test multiple source
imgs = [
SOURCE, # filename
im = cv2.imread(str(SOURCE))
assert len(model(source=Image.open(SOURCE), save=True, verbose=True)) == 1 # PIL
assert len(model(source=im, save=True, save_txt=True)) == 1 # ndarray
assert len(model(source=[im, im], save=True, save_txt=True)) == 2 # batch
assert len(list(model(source=[im, im], save=True, stream=True))) == 2 # stream
assert len(model(torch.zeros(320, 640, 3).numpy())) == 1 # tensor to numpy
batch = [
str(SOURCE), # filename
Path(SOURCE), # Path
'https://ultralytics.com/images/zidane.jpg', # URI
'https://ultralytics.com/images/zidane.jpg' if checks.check_online() else SOURCE, # URI
cv2.imread(str(SOURCE)), # OpenCV
Image.open(SOURCE), # PIL
np.zeros((320, 640, 3))] # numpy
output = model(imgs)
assert len(output) == 6, 'predict test failed!'
assert len(model(batch)) == len(batch) # multiple sources in a batch


def test_predict_grey_and_4ch():
@@ -85,6 +77,11 @@ def test_val():
model.val(data='coco8.yaml', imgsz=32)


def test_val_scratch():
model = YOLO(CFG)
model.val(data='coco8.yaml', imgsz=32)


def test_train_scratch():
model = YOLO(CFG)
model.train(data='coco8.yaml', epochs=1, imgsz=32)
@@ -103,6 +100,12 @@ def test_export_torchscript():
YOLO(f)(SOURCE) # exported model inference


def test_export_torchscript_scratch():
model = YOLO(CFG)
f = model.export(format='torchscript')
YOLO(f)(SOURCE) # exported model inference


def test_export_onnx():
model = YOLO(MODEL)
f = model.export(format='onnx')

@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license

__version__ = '8.0.43'
__version__ = '8.0.44'

from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils.checks import check_yolo as checks

@@ -15,7 +15,7 @@ EXPORT_FORMATS_HUB = EXPORT_FORMATS_LIST + ['ultralytics_tflite', 'ultralytics_c

def start(key=''):
"""
Start training models with Ultralytics HUB. Usage: from src.ultralytics import start; start('API_KEY')
Start training models with Ultralytics HUB. Usage: from ultralytics.hub import start; start('API_KEY')
"""
auth = Auth(key)
try:
@@ -30,9 +30,9 @@ def start(key=''):
session = HubTrainingSession(model_id=model_id, auth=auth)
session.check_disk_space()

trainer = YOLO(session.input_file)
session.register_callbacks(trainer)
trainer.train(**session.train_args)
model = YOLO(session.input_file)
session.register_callbacks(model)
model.train(**session.train_args)
except Exception as e:
LOGGER.warning(f'{PREFIX}{e}')

@@ -93,6 +93,5 @@ def get_export(key='', format='torchscript'):
return r.json()


# temp. For checking
if __name__ == '__main__':
start()

@@ -26,6 +26,7 @@ class HubTrainingSession:
self._timers = {} # rate limit timers (seconds)
self._metrics_queue = {} # metrics queue
self.model = self._get_model()
self.alive = True
self._start_heartbeat() # start heartbeats
self._register_signal_handlers()

@@ -52,37 +53,6 @@ class HubTrainingSession:
payload = {'metrics': self._metrics_queue.copy(), 'type': 'metrics'}
smart_request(f'{self.api_url}', json=payload, headers=self.auth_header, code=2)

def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
# Upload a model to HUB
file = None
if Path(weights).is_file():
with open(weights, 'rb') as f:
file = f.read()
if final:
smart_request(
f'{self.api_url}/upload',
data={
'epoch': epoch,
'type': 'final',
'map': map},
files={'best.pt': file},
headers=self.auth_header,
retry=10,
timeout=3600,
code=4,
)
else:
smart_request(
f'{self.api_url}/upload',
data={
'epoch': epoch,
'type': 'epoch',
'isBest': bool(is_best)},
headers=self.auth_header,
files={'last.pt': file},
code=3,
)

def _get_model(self):
# Returns model from database by id
api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'
@@ -151,7 +121,7 @@ class HubTrainingSession:
model_info = {
'model/parameters': get_num_params(trainer.model),
'model/GFLOPs': round(get_flops(trainer.model), 3),
'model/speed(ms)': round(trainer.validator.speed[1], 3)}
'model/speed(ms)': round(trainer.validator.speed['inference'], 3)}
all_plots = {**all_plots, **model_info}
self._metrics_queue[trainer.epoch] = json.dumps(all_plots)
if time() - self._timers['metrics'] > self._rate_limits['metrics']:
@@ -169,52 +139,45 @@ class HubTrainingSession:

def on_train_end(self, trainer):
# Upload final model and metrics with exponential standoff
LOGGER.info(f'{PREFIX}Training completed successfully ✅')
LOGGER.info(f'{PREFIX}Uploading final {self.model_id}')
LOGGER.info(f'{PREFIX}Training completed successfully ✅\n'
f'{PREFIX}Uploading final {self.model_id}')

# hack for fetching mAP
mAP = trainer.metrics.get('metrics/mAP50-95(B)', 0)
self._upload_model(trainer.epoch, trainer.best, map=mAP, final=True) # results[3] is mAP0.5:0.95
self._upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True)
self.alive = False # stop heartbeats
LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀')

def _upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
# Upload a model to HUB
file = None
if Path(weights).is_file():
with open(weights, 'rb') as f:
file = f.read()
file_param = {'best.pt' if final else 'last.pt': file}
endpoint = f'{self.api_url}/upload'
else:
LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload failed. Missing model {weights}.')
file = None
data = {'epoch': epoch}
if final:
data.update({'type': 'final', 'map': map})
else:
data.update({'type': 'epoch', 'isBest': bool(is_best)})

smart_request(
endpoint,
data=data,
files=file_param,
headers=self.auth_header,
retry=10 if final else None,
timeout=3600 if final else None,
code=4 if final else 3,
)
smart_request(f'{self.api_url}/upload',
data=data,
files={'best.pt' if final else 'last.pt': file},
headers=self.auth_header,
retry=10 if final else None,
timeout=3600 if final else None,
code=4 if final else 3)

@threaded
def _start_heartbeat(self):
self.alive = True
while self.alive:
r = smart_request(
f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
json={
'agent': AGENT_NAME,
'agentId': self.agent_id},
headers=self.auth_header,
retry=0,
code=5,
thread=False,
)
r = smart_request(f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
json={
'agent': AGENT_NAME,
'agentId': self.agent_id},
headers=self.auth_header,
retry=0,
code=5,
thread=False)
self.agent_id = r.json().get('data', {}).get('agentId', None)
sleep(self._rate_limits['heartbeat'])

@@ -181,7 +181,6 @@ class AutoBackend(nn.Module):
import tensorflow as tf
keras = False # assume TF1 saved_model
model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
w = Path(w) / 'metadata.yaml'
elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
import tensorflow as tf
@@ -258,8 +257,9 @@ class AutoBackend(nn.Module):
f'\n\n{EXPORT_FORMATS_TABLE}')

# Load external metadata YAML
w = Path(w)
if xml or saved_model or paddle:
metadata = Path(w).parent / 'metadata.yaml'
metadata = (w if saved_model else w.parents[1] if paddle else w.parent) / 'metadata.yaml'
if metadata.exists():
metadata = yaml_load(metadata)
stride, names = int(metadata['stride']), metadata['names'] # load metadata

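Note on the metadata change above: the location of the external metadata.yaml now depends on the export format rather than always being the parent of the weights path. A minimal sketch of the new expression, kept separate from AutoBackend for illustration only; the example paths are hypothetical:

from pathlib import Path

def metadata_path(w, saved_model=False, paddle=False):
    # Mirrors the replacement line: a TF SavedModel keeps metadata.yaml inside the
    # export directory itself, Paddle uses w.parents[1] (two levels above the weights
    # file), and other formats such as OpenVINO keep it next to the weights file.
    w = Path(w)
    return (w if saved_model else w.parents[1] if paddle else w.parent) / 'metadata.yaml'

print(metadata_path('yolov8n_saved_model', saved_model=True))       # yolov8n_saved_model/metadata.yaml
print(metadata_path('yolov8n_openvino_model/yolov8n.xml'))          # yolov8n_openvino_model/metadata.yaml
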
@@ -287,6 +287,7 @@ class ClassificationModel(BaseModel):
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
self.yaml['nc'] = nc # override yaml value
self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
self.stride = torch.Tensor([1]) # no stride constraints
self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
self.info()

@@ -520,14 +521,15 @@ def guess_model_task(model):

# Guess from model filename
if isinstance(model, (str, Path)):
model = Path(model).stem
if '-seg' in model:
model = Path(model)
if '-seg' in model.stem or 'segment' in model.parts:
return 'segment'
elif '-cls' in model:
elif '-cls' in model.stem or 'classify' in model.parts:
return 'classify'
else:
elif 'detect' in model.parts:
return 'detect'

# Unable to determine task from model
raise SyntaxError('YOLO is unable to automatically guess model task. Explicitly define task for your model, '
"i.e. 'task=detect', 'task=segment' or 'task=classify'.")
LOGGER.warning("WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
"Explicitly define task for your model, i.e. 'task=detect', 'task=segment' or 'task=classify'.")
return 'detect' # assume detect

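Note on the guess_model_task change above: when the task cannot be inferred from the config, checkpoint args, model head, or filename, the function no longer raises SyntaxError; it warns and falls back to 'detect'. It also now inspects both the filename stem and the parent directory names. A minimal standalone sketch of that filename branch, assuming only pathlib and a logger (guess_task_from_filename is an illustrative name, not the real function):

from pathlib import Path
import logging

LOGGER = logging.getLogger('sketch')

def guess_task_from_filename(model):
    # Check the stem ('-seg', '-cls') and the path parts ('segment', 'classify', 'detect'),
    # then fall back to 'detect' with a warning instead of raising.
    model = Path(model)
    if '-seg' in model.stem or 'segment' in model.parts:
        return 'segment'
    elif '-cls' in model.stem or 'classify' in model.parts:
        return 'classify'
    elif 'detect' in model.parts:
        return 'detect'
    LOGGER.warning("Unable to guess model task, assuming 'task=detect'.")
    return 'detect'

print(guess_task_from_filename('runs/segment/train/weights/best.pt'))  # -> 'segment'
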
@@ -47,11 +47,12 @@ CLI_HELP_MSG = \
GitHub: https://github.com/ultralytics/ultralytics
"""

CFG_FLOAT_KEYS = {'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear'}
# Define keys for arg type checks
CFG_FLOAT_KEYS = {'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear', 'fl_gamma'}
CFG_FRACTION_KEYS = {
'dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr', 'fl_gamma',
'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud', 'fliplr', 'mosaic',
'mixup', 'copy_paste', 'conf', 'iou'}
'dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr', 'label_smoothing',
'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud', 'fliplr', 'mosaic', 'mixup', 'copy_paste',
'conf', 'iou'} # fractional floats limited to 0.0 - 1.0
CFG_INT_KEYS = {
'epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
'line_thickness', 'workspace', 'nbs', 'save_period'}
@@ -224,7 +225,7 @@ def entrypoint(debug=''):
assert v, f"missing '{k}' value"
if k == 'cfg': # custom.yaml passed
LOGGER.info(f'Overriding {DEFAULT_CFG_PATH} with {v}')
overrides = {k: val for k, val in yaml_load(v).items() if k != 'cfg'}
overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != 'cfg'}
else:
if v.lower() == 'none':
v = None
@@ -255,7 +256,6 @@ def entrypoint(debug=''):
check_cfg_mismatch(full_args_dict, {a: ''})

# Defaults
task2model = dict(detect='yolov8n.pt', segment='yolov8n-seg.pt', classify='yolov8n-cls.pt')
task2data = dict(detect='coco128.yaml', segment='coco128-seg.yaml', classify='imagenet100')

# Mode
@@ -272,27 +272,28 @@ def entrypoint(debug=''):

# Model
model = overrides.pop('model', DEFAULT_CFG.model)
task = overrides.pop('task', None)
if model is None:
model = task2model.get(task, 'yolov8n.pt')
model = 'yolov8n.pt'
LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
from ultralytics.yolo.engine.model import YOLO
overrides['model'] = model
model = YOLO(model)

# Task
if task and task != model.task:
LOGGER.warning(f"WARNING ⚠️ 'task={task}' conflicts with {model.task} model {overrides['model']}. "
f"Inheriting 'task={model.task}' from {overrides['model']} and ignoring 'task={task}'.")
task = model.task
overrides['task'] = task
# if task and task != model.task:
#     LOGGER.warning(f"WARNING ⚠️ 'task={task}' conflicts with {model.task} model {overrides['model']}. "
#                    f"Inheriting 'task={model.task}' from {overrides['model']} and ignoring 'task={task}'.")
overrides['task'] = overrides.get('task', model.task)
model.task = overrides['task']

# Mode
if mode in {'predict', 'track'} and 'source' not in overrides:
overrides['source'] = DEFAULT_CFG.source or ROOT / 'assets' if (ROOT / 'assets').exists() \
else 'https://ultralytics.com/images/bus.jpg'
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
elif mode in ('train', 'val'):
if 'data' not in overrides:
overrides['data'] = task2data.get(task, DEFAULT_CFG.data)
overrides['data'] = task2data.get(overrides['task'], DEFAULT_CFG.data)
LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using {model.task} default 'data={overrides['data']}'.")
elif mode == 'export':
if 'format' not in overrides:

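Note on the key sets in the first hunk above: CFG_FLOAT_KEYS, CFG_FRACTION_KEYS and CFG_INT_KEYS drive argument type checking in the config loader, and fraction keys are limited to the range 0.0 to 1.0. A minimal sketch of how a fraction key could be validated under those definitions; the real check lives in get_cfg and may differ in wording and behavior:

CFG_FRACTION_KEYS = {'dropout', 'iou', 'lr0', 'lrf', 'momentum', 'conf'}  # abbreviated subset

def check_fraction(key, value):
    # Fraction keys accept floats restricted to the closed interval [0.0, 1.0].
    if key in CFG_FRACTION_KEYS:
        value = float(value)
        if not 0.0 <= value <= 1.0:
            raise ValueError(f"'{key}={value}' is invalid; valid values are between 0.0 and 1.0.")
    return value

print(check_fraction('conf', '0.25'))  # -> 0.25
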
@@ -6,7 +6,6 @@ Dataloaders and dataset utils
import contextlib
import glob
import hashlib
import json
import math
import os
import random
@@ -27,11 +26,9 @@ from PIL import ExifTags, Image, ImageOps
from torch.utils.data import DataLoader, Dataset, dataloader, distributed
from tqdm import tqdm

from ultralytics.yolo.data.utils import check_det_dataset
from ultralytics.yolo.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, is_colab, is_dir_writeable,
is_kaggle, yaml_load)
from ultralytics.yolo.utils.checks import check_requirements, check_yaml
from ultralytics.yolo.utils.downloads import unzip_file
is_kaggle)
from ultralytics.yolo.utils.checks import check_requirements
from ultralytics.yolo.utils.ops import clean_str, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn
from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first

@@ -1037,127 +1034,6 @@ def verify_image_label(args):
return [None, None, None, None, nm, nf, ne, nc, msg]


class HUBDatasetStats():
""" Class for generating HUB dataset JSON and `-hub` dataset directory

Arguments
path: Path to data.yaml or data.zip (with data.yaml inside data.zip)
autodownload: Attempt to download dataset if not found locally

Usage
from ultralytics.yolo.data.dataloaders.v5loader import HUBDatasetStats
stats = HUBDatasetStats('coco128.yaml', autodownload=True) # usage 1
stats = HUBDatasetStats('path/to/coco128.zip') # usage 2
stats.get_json(save=False)
stats.process_images()
"""

def __init__(self, path='coco128.yaml', autodownload=False):
# Initialize class
zipped, data_dir, yaml_path = self._unzip(Path(path))
# try:
#     data = yaml_load(check_yaml(yaml_path)) # data dict
#     if zipped:
#         data['path'] = data_dir
# except Exception as e:
#     raise Exception('error/HUB/dataset_stats/yaml_load') from e

data = check_det_dataset(yaml_path, autodownload) # download dataset if missing
self.hub_dir = Path(str(data['path']) + '-hub')
self.im_dir = self.hub_dir / 'images'
self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images
self.stats = {'nc': data['nc'], 'names': list(data['names'].values())} # statistics dictionary
self.data = data

@staticmethod
def _find_yaml(dir):
# Return data.yaml file
files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml')) # try root level first and then recursive
assert files, f'No *.yaml file found in {dir}'
if len(files) > 1:
files = [f for f in files if f.stem == dir.stem] # prefer *.yaml files that match dir name
assert files, f'Multiple *.yaml files found in {dir}, only 1 *.yaml file allowed'
assert len(files) == 1, f'Multiple *.yaml files found: {files}, only 1 *.yaml file allowed in {dir}'
return files[0]

def _unzip(self, path):
# Unzip data.zip
if not str(path).endswith('.zip'): # path is data.yaml
return False, None, path
assert Path(path).is_file(), f'Error unzipping {path}, file not found'
unzip_file(path, path=path.parent)
dir = path.with_suffix('') # dataset directory == zip name
assert dir.is_dir(), f'Error unzipping {path}, {dir} not found. path/to/abc.zip MUST unzip to path/to/abc/'
return True, str(dir), self._find_yaml(dir) # zipped, data_dir, yaml_path

def _hub_ops(self, f, max_dim=1920):
# HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing
f_new = self.im_dir / Path(f).name # dataset-hub image filename
try: # use PIL
im = Image.open(f)
r = max_dim / max(im.height, im.width) # ratio
if r < 1.0: # image too large
im = im.resize((int(im.width * r), int(im.height * r)))
im.save(f_new, 'JPEG', quality=50, optimize=True) # save
except Exception as e: # use OpenCV
LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}')
im = cv2.imread(f)
im_height, im_width = im.shape[:2]
r = max_dim / max(im_height, im_width) # ratio
if r < 1.0: # image too large
im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
cv2.imwrite(str(f_new), im)

def get_json(self, save=False, verbose=False):
# Return dataset JSON for Ultralytics HUB
def _round(labels):
# Update labels to integer class and 6 decimal place floats
return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels]

for split in 'train', 'val', 'test':
if self.data.get(split) is None:
self.stats[split] = None # i.e. no test set
continue
dataset = LoadImagesAndLabels(self.data[split]) # load dataset
x = np.array([
np.bincount(label[:, 0].astype(int), minlength=self.data['nc'])
for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics')]) # shape(128x80)
self.stats[split] = {
'instance_stats': {
'total': int(x.sum()),
'per_class': x.sum(0).tolist()},
'image_stats': {
'total': dataset.n,
'unlabelled': int(np.all(x == 0, 1).sum()),
'per_class': (x > 0).sum(0).tolist()},
'labels': [{
str(Path(k).name): _round(v.tolist())} for k, v in zip(dataset.im_files, dataset.labels)]}

# Save, print and return
if save:
stats_path = self.hub_dir / 'stats.json'
LOGGER.info(f'Saving {stats_path.resolve()}...')
with open(stats_path, 'w') as f:
json.dump(self.stats, f) # save stats.json
if verbose:
LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
return self.stats

def process_images(self):
# Compress images for Ultralytics HUB
for split in 'train', 'val', 'test':
if self.data.get(split) is None:
continue
dataset = LoadImagesAndLabels(self.data[split]) # load dataset
desc = f'{split} images'
total = dataset.n
with ThreadPool(NUM_THREADS) as pool:
for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=total, desc=desc):
pass
LOGGER.info(f'Done. All images saved to {self.im_dir}')
return self.im_dir


# Classification dataloaders -------------------------------------------------------------------------------------------
class ClassificationDataset(torchvision.datasets.ImageFolder):
"""

@@ -2,9 +2,11 @@

import contextlib
import hashlib
import json
import os
import subprocess
import time
from multiprocessing.pool import ThreadPool
from pathlib import Path
from tarfile import is_tarfile
from zipfile import is_zipfile
@@ -12,10 +14,11 @@ from zipfile import is_zipfile
import cv2
import numpy as np
from PIL import ExifTags, Image, ImageOps
from tqdm import tqdm

from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, ROOT, colorstr, emojis, yaml_load
from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, colorstr, emojis, yaml_load
from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii
from ultralytics.yolo.utils.downloads import download, safe_download
from ultralytics.yolo.utils.downloads import download, safe_download, unzip_file
from ultralytics.yolo.utils.ops import segments2boxes

HELP_URL = 'See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
@@ -290,3 +293,128 @@ def check_cls_dataset(dataset: str):
names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list
names = dict(enumerate(sorted(names)))
return {'train': train_set, 'val': test_set, 'nc': nc, 'names': names}


class HUBDatasetStats():
""" Class for generating HUB dataset JSON and `-hub` dataset directory

Arguments
path: Path to data.yaml or data.zip (with data.yaml inside data.zip)
autodownload: Attempt to download dataset if not found locally

Usage
from ultralytics.yolo.data.utils import HUBDatasetStats
stats = HUBDatasetStats('coco128.yaml', autodownload=True) # usage 1
stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco6.zip') # usage 2
stats.get_json(save=False)
stats.process_images()
"""

def __init__(self, path='coco128.yaml', autodownload=False):
# Initialize class
zipped, data_dir, yaml_path = self._unzip(Path(path))
try:
# data = yaml_load(check_yaml(yaml_path)) # data dict
data = check_det_dataset(yaml_path, autodownload) # data dict
if zipped:
data['path'] = data_dir
except Exception as e:
raise Exception('error/HUB/dataset_stats/yaml_load') from e

self.hub_dir = Path(str(data['path']) + '-hub')
self.im_dir = self.hub_dir / 'images'
self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images
self.stats = {'nc': len(data['names']), 'names': list(data['names'].values())} # statistics dictionary
self.data = data

@staticmethod
def _find_yaml(dir):
# Return data.yaml file
files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml')) # try root level first and then recursive
assert files, f'No *.yaml file found in {dir}'
if len(files) > 1:
files = [f for f in files if f.stem == dir.stem] # prefer *.yaml files that match dir name
assert files, f'Multiple *.yaml files found in {dir}, only 1 *.yaml file allowed'
assert len(files) == 1, f'Multiple *.yaml files found: {files}, only 1 *.yaml file allowed in {dir}'
return files[0]

def _unzip(self, path):
# Unzip data.zip
if not str(path).endswith('.zip'): # path is data.yaml
return False, None, path
assert Path(path).is_file(), f'Error unzipping {path}, file not found'
unzip_file(path, path=path.parent)
dir = path.with_suffix('') # dataset directory == zip name
assert dir.is_dir(), f'Error unzipping {path}, {dir} not found. path/to/abc.zip MUST unzip to path/to/abc/'
return True, str(dir), self._find_yaml(dir) # zipped, data_dir, yaml_path

def _hub_ops(self, f, max_dim=1920):
# HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing
f_new = self.im_dir / Path(f).name # dataset-hub image filename
try: # use PIL
im = Image.open(f)
r = max_dim / max(im.height, im.width) # ratio
if r < 1.0: # image too large
im = im.resize((int(im.width * r), int(im.height * r)))
im.save(f_new, 'JPEG', quality=50, optimize=True) # save
except Exception as e: # use OpenCV
LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}')
im = cv2.imread(f)
im_height, im_width = im.shape[:2]
r = max_dim / max(im_height, im_width) # ratio
if r < 1.0: # image too large
im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
cv2.imwrite(str(f_new), im)

def get_json(self, save=False, verbose=False):
# Return dataset JSON for Ultralytics HUB
# from ultralytics.yolo.data import YOLODataset
from ultralytics.yolo.data.dataloaders.v5loader import LoadImagesAndLabels

def _round(labels):
# Update labels to integer class and 6 decimal place floats
return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels]

for split in 'train', 'val', 'test':
if self.data.get(split) is None:
self.stats[split] = None # i.e. no test set
continue
dataset = LoadImagesAndLabels(self.data[split]) # load dataset
x = np.array([
np.bincount(label[:, 0].astype(int), minlength=self.data['nc'])
for label in tqdm(dataset.labels, total=len(dataset), desc='Statistics')]) # shape(128x80)
self.stats[split] = {
'instance_stats': {
'total': int(x.sum()),
'per_class': x.sum(0).tolist()},
'image_stats': {
'total': len(dataset),
'unlabelled': int(np.all(x == 0, 1).sum()),
'per_class': (x > 0).sum(0).tolist()},
'labels': [{
str(Path(k).name): _round(v.tolist())} for k, v in zip(dataset.im_files, dataset.labels)]}

# Save, print and return
if save:
stats_path = self.hub_dir / 'stats.json'
LOGGER.info(f'Saving {stats_path.resolve()}...')
with open(stats_path, 'w') as f:
json.dump(self.stats, f) # save stats.json
if verbose:
LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
return self.stats

def process_images(self):
# Compress images for Ultralytics HUB
# from ultralytics.yolo.data import YOLODataset
from ultralytics.yolo.data.dataloaders.v5loader import LoadImagesAndLabels

for split in 'train', 'val', 'test':
if self.data.get(split) is None:
continue
dataset = LoadImagesAndLabels(self.data[split]) # load dataset
with ThreadPool(NUM_THREADS) as pool:
for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f'{split} images'):
pass
LOGGER.info(f'Done. All images saved to {self.im_dir}')
return self.im_dir

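Note on the two hunks above: HUBDatasetStats moves out of the v5loader module and into ultralytics.yolo.data.utils, matching the import shown in its docstring. A minimal usage sketch based on that docstring, assuming a local coco128.yaml (or a dataset zip) is available:

from ultralytics.yolo.data.utils import HUBDatasetStats

stats = HUBDatasetStats('coco128.yaml', autodownload=True)  # or HUBDatasetStats('path/to/dataset.zip')
stats.get_json(save=True)   # writes stats.json into the '<dataset path>-hub' directory
stats.process_images()      # writes compressed JPEG copies into '<dataset path>-hub/images'
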
@@ -208,12 +208,15 @@ class Exporter:
self.file = file
self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y)
self.pretty_name = self.file.stem.replace('yolo', 'YOLO')
description = f'Ultralytics {self.pretty_name} model' + f'trained on {Path(self.args.data).name}' \
if self.args.data else '(untrained)'
self.metadata = {
'description': f'Ultralytics {self.pretty_name} model trained on {Path(self.args.data).name}',
'description': description,
'author': 'Ultralytics',
'license': 'GPL-3.0 https://ultralytics.com/license',
'version': __version__,
'stride': int(max(model.stride)),
'task': model.task,
'names': model.names} # model metadata

LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} BCHW and "

@@ -9,76 +9,72 @@ from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, Segmentat
guess_model_task, nn)
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, yaml_load
from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, callbacks, yaml_load
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_yaml
from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS
from ultralytics.yolo.utils.torch_utils import smart_inference_mode

# Map head to model, trainer, validator, and predictor classes
MODEL_MAP = {
TASK_MAP = {
'classify': [
ClassificationModel, 'yolo.TYPE.classify.ClassificationTrainer', 'yolo.TYPE.classify.ClassificationValidator',
'yolo.TYPE.classify.ClassificationPredictor'],
ClassificationModel, yolo.v8.classify.ClassificationTrainer, yolo.v8.classify.ClassificationValidator,
yolo.v8.classify.ClassificationPredictor],
'detect': [
DetectionModel, 'yolo.TYPE.detect.DetectionTrainer', 'yolo.TYPE.detect.DetectionValidator',
'yolo.TYPE.detect.DetectionPredictor'],
DetectionModel, yolo.v8.detect.DetectionTrainer, yolo.v8.detect.DetectionValidator,
yolo.v8.detect.DetectionPredictor],
'segment': [
SegmentationModel, 'yolo.TYPE.segment.SegmentationTrainer', 'yolo.TYPE.segment.SegmentationValidator',
'yolo.TYPE.segment.SegmentationPredictor']}
SegmentationModel, yolo.v8.segment.SegmentationTrainer, yolo.v8.segment.SegmentationValidator,
yolo.v8.segment.SegmentationPredictor]}


class YOLO:
"""
YOLO (You Only Look Once) object detection model.
YOLO (You Only Look Once) object detection model.

Args:
boxes and docstring follow.
Args:
model (str, Path): Path to the model file to load or create.
type (str): Type/version of models to use. Defaults to "v8".
Args:
model (str, Path): Path to the model file to load or create.

Attributes:
type (str): Type/version of models being used.
ModelClass (Any): Model class.
TrainerClass (Any): Trainer class.
ValidatorClass (Any): Validator class.
PredictorClass (Any): Predictor class.
predictor (Any): Predictor object.
model (Any): Model object.
trainer (Any): Trainer object.
task (str): Type of model task.
ckpt (Any): Checkpoint object if model loaded from *.pt file.
cfg (str): Model configuration if loaded from *.yaml file.
ckpt_path (str): Checkpoint file path.
overrides (dict): Overrides for trainer object.
metrics_data (Any): Data for metrics.
Attributes:
predictor (Any): The predictor object.
model (Any): The model object.
trainer (Any): The trainer object.
task (str): The type of model task.
ckpt (Any): The checkpoint object if the model loaded from *.pt file.
cfg (str): The model configuration if loaded from *.yaml file.
ckpt_path (str): The checkpoint file path.
overrides (dict): Overrides for the trainer object.
metrics_data (Any): The data for metrics.

Methods:
__call__(): Alias for predict method.
_new(cfg, verbose=True): Initializes a new model and infers the task type from the model definitions.
_load(weights): Initializes a new model and infers the task type from the model head.
_check_is_pytorch_model(): Raises TypeError if model is not a PyTorch model.
reset(): Resets the model modules.
info(verbose=False): Logs model info.
fuse(): Fuse model for faster inference.
predict(source=None, stream=False, **kwargs): Perform prediction using the YOLO model.
Methods:
__call__(source=None, stream=False, **kwargs):
Alias for the predict method.
_new(cfg:str, verbose:bool=True) -> None:
Initializes a new model and infers the task type from the model definitions.
_load(weights:str, task:str='') -> None:
Initializes a new model and infers the task type from the model head.
_check_is_pytorch_model() -> None:
Raises TypeError if the model is not a PyTorch model.
reset() -> None:
Resets the model modules.
info(verbose:bool=False) -> None:
Logs the model info.
fuse() -> None:
Fuses the model for faster inference.
predict(source=None, stream=False, **kwargs) -> List[ultralytics.yolo.engine.results.Results]:
Performs prediction using the YOLO model.

Returns:
list(ultralytics.yolo.engine.results.Results): The prediction results.
"""
Returns:
list[ultralytics.yolo.engine.results.Results]: The prediction results.
"""

def __init__(self, model='yolov8n.pt', type='v8') -> None:
def __init__(self, model='yolov8n.pt') -> None:
"""
Initializes the YOLO model.

Args:
model (str, Path): model to load or create
type (str): Type/version of models to use. Defaults to "v8".
"""
self._reset_callbacks()
self.type = type
self.ModelClass = None # model class
self.TrainerClass = None # trainer class
self.ValidatorClass = None # validator class
self.PredictorClass = None # predictor class
self.predictor = None # reuse predictor
self.model = None # model object
self.trainer = None # trainer object
@@ -101,6 +97,10 @@ class YOLO:
def __call__(self, source=None, stream=False, **kwargs):
return self.predict(source, stream, **kwargs)

def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")

def _new(self, cfg: str, verbose=True):
"""
Initializes a new model and infers the task type from the model definitions.
@@ -112,11 +112,15 @@ class YOLO:
self.cfg = check_yaml(cfg) # check YAML
cfg_dict = yaml_load(self.cfg, append_filename=True) # model dict
self.task = guess_model_task(cfg_dict)
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task()
self.model = self.ModelClass(cfg_dict, verbose=verbose and RANK == -1) # initialize
self.model = TASK_MAP[self.task][0](cfg_dict, verbose=verbose and RANK == -1) # build model
self.overrides['model'] = self.cfg

def _load(self, weights: str):
# Below added to allow export from yamls
args = {**DEFAULT_CFG_DICT, **self.overrides} # combine model and default args, preferring model args
self.model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
self.model.task = self.task

def _load(self, weights: str, task=''):
"""
Initializes a new model and infers the task type from the model head.

@@ -127,8 +131,7 @@ class YOLO:
if suffix == '.pt':
self.model, self.ckpt = attempt_load_one_weight(weights)
self.task = self.model.args['task']
self.overrides = self.model.args
self._reset_ckpt_args(self.overrides)
self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
self.ckpt_path = self.model.pt_path
else:
weights = check_file(weights)
@@ -136,7 +139,6 @@ class YOLO:
self.task = guess_model_task(weights)
self.ckpt_path = weights
self.overrides['model'] = weights
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task()

def _check_is_pytorch_model(self):
"""
@@ -189,12 +191,13 @@ class YOLO:
"""
overrides = self.overrides.copy()
overrides['conf'] = 0.25
overrides.update(kwargs)
overrides.update(kwargs) # prefer kwargs
overrides['mode'] = kwargs.get('mode', 'predict')
assert overrides['mode'] in ['track', 'predict']
overrides['save'] = kwargs.get('save', False) # not save files by default
if not self.predictor:
self.predictor = self.PredictorClass(overrides=overrides)
self.task = overrides.get('task') or self.task
self.predictor = TASK_MAP[self.task][3](overrides=overrides)
self.predictor.setup_model(model=self.model)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides)
@@ -226,12 +229,15 @@ class YOLO:
overrides['mode'] = 'val'
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.data = data or args.data
args.task = self.task
if 'task' in overrides:
self.task = args.task
else:
args.task = self.task
if args.imgsz == DEFAULT_CFG.imgsz and not isinstance(self.model, (str, Path)):
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
args.imgsz = check_imgsz(args.imgsz, max_dim=1)

validator = self.ValidatorClass(args=args)
validator = TASK_MAP[self.task][2](args=args)
validator(model=self.model)
self.metrics_data = validator.metrics

@@ -267,8 +273,7 @@ class YOLO:
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
if args.batch == DEFAULT_CFG.batch:
args.batch = 1 # default to 1 if not modified
exporter = Exporter(overrides=args)
return exporter(model=self.model)
return Exporter(overrides=args)(model=self.model)

def train(self, **kwargs):
"""
@@ -282,15 +287,15 @@ class YOLO:
overrides.update(kwargs)
if kwargs.get('cfg'):
LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
overrides = yaml_load(check_yaml(kwargs['cfg']), append_filename=True)
overrides['task'] = self.task
overrides = yaml_load(check_yaml(kwargs['cfg']))
overrides['mode'] = 'train'
if not overrides.get('data'):
raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
if overrides.get('resume'):
overrides['resume'] = self.ckpt_path

self.trainer = self.TrainerClass(overrides=overrides)
self.task = overrides.get('task') or self.task
self.trainer = TASK_MAP[self.task][1](overrides=overrides)
if not overrides.get('resume'): # manually set model only if not resuming
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
@@ -311,13 +316,6 @@ class YOLO:
self._check_is_pytorch_model()
self.model.to(device)

def _assign_ops_from_task(self):
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[self.task]
trainer_class = eval(train_lit.replace('TYPE', f'{self.type}'))
validator_class = eval(val_lit.replace('TYPE', f'{self.type}'))
predictor_class = eval(pred_lit.replace('TYPE', f'{self.type}'))
return model_class, trainer_class, validator_class, predictor_class

@property
def names(self):
"""
@@ -357,9 +355,8 @@ class YOLO:

@staticmethod
def _reset_ckpt_args(args):
for arg in 'augment', 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', \
'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots', 'opset', 'simplify':
args.pop(arg, None)
include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
return {k: v for k, v in args.items() if k in include}

@staticmethod
def _reset_callbacks():

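Note on the changes above: the string-based MODEL_MAP plus the eval() lookup in _assign_ops_from_task is replaced by TASK_MAP, which maps each task name directly to its model, trainer, validator and predictor classes, indexed 0 through 3 as used in _new(), train(), val() and predict(). A minimal sketch of the new lookup, assuming the v8 classes imported in this module; the override values are hypothetical:

# TASK_MAP[task] -> [ModelClass, TrainerClass, ValidatorClass, PredictorClass]
ModelClass, TrainerClass, ValidatorClass, PredictorClass = TASK_MAP['detect']

model = ModelClass('yolov8n.yaml', verbose=False)                    # same call shape as TASK_MAP[self.task][0](cfg_dict, ...)
predictor = PredictorClass(overrides={'imgsz': 320, 'save': False})  # same call shape as TASK_MAP[self.task][3](overrides=overrides)
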
@@ -108,7 +108,6 @@ class BasePredictor:
def postprocess(self, preds, img, orig_img):
return preds

@smart_inference_mode()
def __call__(self, source=None, model=None, stream=False):
if stream:
return self.stream_inference(source, model)
@@ -136,6 +135,7 @@ class BasePredictor:
self.source_type = self.dataset.source_type
self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs

@smart_inference_mode()
def stream_inference(self, source=None, model=None):
if self.args.verbose:
LOGGER.info('')
@@ -161,12 +161,14 @@ class BasePredictor:
self.batch = batch
path, im, im0s, vid_cap, s = batch
visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False

# preprocess
with self.dt[0]:
im = self.preprocess(im)
if len(im.shape) == 3:
im = im[None] # expand for batch dim

# Inference
# inference
with self.dt[1]:
preds = self.model(im, augment=self.args.augment, visualize=visualize)

@@ -18,29 +18,33 @@ from ultralytics.yolo.utils.plotting import Annotator, colors

class Results:
"""
A class for storing and manipulating inference results.
A class for storing and manipulating inference results.

Args:
boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
masks (Masks, optional): A Masks object containing the detection masks.
probs (torch.Tensor, optional): A tensor containing the detection class probabilities.
orig_img (tuple, optional): Original image size.
Args:
orig_img (numpy.ndarray): The original image as a numpy array.
path (str): The path to the image file.
names (List[str]): A list of class names.
boxes (List[List[float]], optional): A list of bounding box coordinates for each detection.
masks (numpy.ndarray, optional): A 3D numpy array of detection masks, where each mask is a binary image.
probs (numpy.ndarray, optional): A 2D numpy array of detection probabilities for each class.

Attributes:
boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
masks (Masks, optional): A Masks object containing the detection masks.
probs (torch.Tensor, optional): A tensor containing the detection class probabilities.
orig_img (tuple, optional): Original image size.
data (torch.Tensor): The raw masks tensor

"""
Attributes:
orig_img (numpy.ndarray): The original image as a numpy array.
orig_shape (tuple): The original image shape in (height, width) format.
boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
masks (Masks, optional): A Masks object containing the detection masks.
probs (numpy.ndarray, optional): A 2D numpy array of detection probabilities for each class.
names (List[str]): A list of class names.
path (str): The path to the image file.
_keys (tuple): A tuple of attribute names for non-empty attributes.
"""

def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None) -> None:
self.orig_img = orig_img
self.orig_shape = orig_img.shape[:2]
self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None # native size boxes
self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks
self.probs = probs if probs is not None else None
self.boxes = Boxes(boxes.cpu(), self.orig_shape) if boxes is not None else None # native size boxes
self.masks = Masks(masks.cpu(), self.orig_shape) if masks is not None else None # native size or imgsz masks
self.probs = probs.cpu() if probs is not None else None
self.names = names
self.path = path
self._keys = (k for k in ('boxes', 'masks', 'probs') if getattr(self, k) is not None)
@@ -99,24 +103,22 @@ class Results:

def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"""
'{name}' object has no attribute '{attr}'. Valid '{name}' object attributes and properties are:

Attributes:
boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
masks (Masks, optional): A Masks object containing the detection masks.
probs (torch.Tensor, optional): A tensor containing the detection class probabilities.
orig_shape (tuple, optional): Original image size.
""")
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")

def plot(self, show_conf=True, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
"""
Plots the given result on an input RGB image. Accepts cv2(numpy) or PIL Image
Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image.

Args:
show_conf (bool): Show confidence
line_width (Float): The line width of boxes. Automatically scaled to img size if not provided
font_size (Float): The font size of . Automatically scaled to img size if not provided
show_conf (bool): Whether to show the detection confidence score.
line_width (float, optional): The line width of the bounding boxes. If None, it is automatically scaled to the image size.
font_size (float, optional): The font size of the text. If None, it is automatically scaled to the image size.
font (str): The font to use for the text.
pil (bool): Whether to return the image as a PIL Image.
example (str): An example string to display in the plot. Useful for indicating the expected format of the output.

Returns:
None or PIL Image: If `pil` is True, the image will be returned as a PIL Image. Otherwise, nothing is returned.
"""
img = deepcopy(self.orig_img)
annotator = Annotator(img, line_width, font_size, font, pil, example)
@@ -157,15 +159,24 @@ class Boxes:
boxes (torch.Tensor) or (numpy.ndarray): A tensor or numpy array containing the detection boxes,
with shape (num_boxes, 6).
orig_shape (torch.Tensor) or (numpy.ndarray): Original image size, in the format (height, width).
is_track (bool): True if the boxes also include track IDs, False otherwise.

Properties:
xyxy (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format.
conf (torch.Tensor) or (numpy.ndarray): The confidence values of the boxes.
cls (torch.Tensor) or (numpy.ndarray): The class values of the boxes.
id (torch.Tensor) or (numpy.ndarray): The track IDs of the boxes (if available).
xywh (torch.Tensor) or (numpy.ndarray): The boxes in xywh format.
xyxyn (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format normalized by original image size.
xywhn (torch.Tensor) or (numpy.ndarray): The boxes in xywh format normalized by original image size.
data (torch.Tensor): The raw bboxes tensor

Methods:
cpu(): Move the object to CPU memory.
numpy(): Convert the object to a numpy array.
cuda(): Move the object to CUDA memory.
to(*args, **kwargs): Move the object to the specified device.
pandas(): Convert the object to a pandas DataFrame (not yet implemented).
"""

def __init__(self, boxes, orig_shape) -> None:
@@ -257,22 +268,7 @@ class Boxes:

def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"""
'{name}' object has no attribute '{attr}'. Valid '{name}' object attributes and properties are:

Attributes:
boxes (torch.Tensor) or (numpy.ndarray): A tensor or numpy array containing the detection boxes,
with shape (num_boxes, 6).
orig_shape (torch.Tensor) or (numpy.ndarray): Original image size, in the format (height, width).

Properties:
xyxy (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format.
conf (torch.Tensor) or (numpy.ndarray): The confidence values of the boxes.
cls (torch.Tensor) or (numpy.ndarray): The class values of the boxes.
xywh (torch.Tensor) or (numpy.ndarray): The boxes in xywh format.
xyxyn (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format normalized by original image size.
xywhn (torch.Tensor) or (numpy.ndarray): The boxes in xywh format normalized by original image size.
""")
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")


class Masks:
@@ -288,7 +284,18 @@ class Masks:
orig_shape (tuple): Original image size, in the format (height, width).

Properties:
segments (list): A list of segments which includes x,y,w,h,label,confidence, and mask of each detection masks.
segments (list): A list of segments which includes x, y, w, h, label, confidence, and mask of each detection masks.

Methods:
cpu(): Returns a copy of the masks tensor on CPU memory.
numpy(): Returns a copy of the masks tensor as a numpy array.
cuda(): Returns a copy of the masks tensor on GPU memory.
to(): Returns a copy of the masks tensor with the specified device and dtype.
__len__(): Returns the number of masks in the tensor.
__str__(): Returns a string representation of the masks tensor.
__repr__(): Returns a detailed string representation of the masks tensor.
__getitem__(): Returns a new Masks object with the masks at the specified index.
__getattr__(): Raises an AttributeError with a list of valid attributes and properties.
"""

def __init__(self, masks, orig_shape) -> None:
@@ -337,13 +344,4 @@ class Masks:

def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"""
'{name}' object has no attribute '{attr}'. Valid '{name}' object attributes and properties are:

Attributes:
masks (torch.Tensor): A tensor containing the detection masks, with shape (num_masks, height, width).
orig_shape (tuple): Original image size, in the format (height, width).

Properties:
segments (list): A list of segments which includes x,y,w,h,label,confidence, and mask of each detection masks.
""")
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")

@ -44,7 +44,6 @@ class BaseTrainer:
Attributes:
args (SimpleNamespace): Configuration for the trainer.
check_resume (method): Method to check if training should be resumed from a saved checkpoint.
console (logging.Logger): Logger instance.
validator (BaseValidator): Validator instance.
model (nn.Module): Model instance.
callbacks (defaultdict): Dictionary of callbacks.
@ -84,7 +83,6 @@ class BaseTrainer:
self.args = get_cfg(cfg, overrides)
self.device = select_device(self.args.device, self.args.batch)
self.check_resume()
self.console = LOGGER
self.validator = None
self.model = None
self.metrics = None
@ -180,11 +178,12 @@ class BaseTrainer:
if world_size > 1 and 'LOCAL_RANK' not in os.environ:
cmd, file = generate_ddp_command(world_size, self) # security vulnerability in Snyk scans
try:
LOGGER.info(f'Running DDP command {cmd}')
subprocess.run(cmd, check=True)
except Exception as e:
self.console.warning(e)
LOGGER.warning(e)
finally:
ddp_cleanup(self, file)
ddp_cleanup(self, str(file))
else:
self._do_train(RANK, world_size)
@ -193,7 +192,7 @@ class BaseTrainer:
# os.environ['MASTER_PORT'] = '9020'
torch.cuda.set_device(rank)
self.device = torch.device('cuda', rank)
self.console.info(f'DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}')
LOGGER.info(f'DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}')
dist.init_process_group('nccl' if dist.is_nccl_available() else 'gloo', rank=rank, world_size=world_size)

def _setup_train(self, rank, world_size):
@ -262,10 +261,10 @@ class BaseTrainer:
nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations
last_opt_step = -1
self.run_callbacks('on_train_start')
self.log(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
f"Logging results to {colorstr('bold', self.save_dir)}\n"
f'Starting training for {self.epochs} epochs...')
LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
f"Logging results to {colorstr('bold', self.save_dir)}\n"
f'Starting training for {self.epochs} epochs...')
if self.args.close_mosaic:
base_idx = (self.epochs - self.args.close_mosaic) * nb
self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
@ -278,14 +277,14 @@ class BaseTrainer:
pbar = enumerate(self.train_loader)
# Update dataloader attributes (optional)
if epoch == (self.epochs - self.args.close_mosaic):
self.console.info('Closing dataloader mosaic')
LOGGER.info('Closing dataloader mosaic')
if hasattr(self.train_loader.dataset, 'mosaic'):
self.train_loader.dataset.mosaic = False
if hasattr(self.train_loader.dataset, 'close_mosaic'):
self.train_loader.dataset.close_mosaic(hyp=self.args)

if rank in {-1, 0}:
self.console.info(self.progress_string())
LOGGER.info(self.progress_string())
pbar = tqdm(enumerate(self.train_loader), total=nb, bar_format=TQDM_BAR_FORMAT)
self.tloss = None
self.optimizer.zero_grad()
@ -372,12 +371,11 @@ class BaseTrainer:

if rank in {-1, 0}:
# Do final val with best.pt
self.log(f'\n{epoch - self.start_epoch + 1} epochs completed in '
f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
LOGGER.info(f'\n{epoch - self.start_epoch + 1} epochs completed in '
f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
self.final_eval()
if self.args.plots:
self.plot_metrics()
self.log(f"Results saved to {colorstr('bold', self.save_dir)}")
self.run_callbacks('on_train_end')
torch.cuda.empty_cache()
self.run_callbacks('teardown')
@ -450,18 +448,6 @@ class BaseTrainer:
self.best_fitness = fitness
return metrics, fitness

def log(self, text, rank=-1):
"""
Logs the given text to given ranks process if provided, otherwise logs to all ranks.

Args"
text (str): text to log
rank (List[Int]): process rank

"""
if rank in {-1, 0}:
self.console.info(text)

def get_model(self, cfg=None, weights=None, verbose=True):
raise NotImplementedError("This task trainer doesn't support loading cfg files")

@ -521,7 +507,7 @@ class BaseTrainer:
if f.exists():
strip_optimizer(f) # strip optimizers
if f is self.best:
self.console.info(f'\nValidating {f}...')
LOGGER.info(f'\nValidating {f}...')
self.metrics = self.validator(model=f)
self.metrics.pop('fitness', None)
self.run_callbacks('on_fit_epoch_end')
@ -564,7 +550,7 @@ class BaseTrainer:
self.best_fitness = best_fitness
self.start_epoch = start_epoch
if start_epoch > (self.epochs - self.args.close_mosaic):
self.console.info('Closing dataloader mosaic')
LOGGER.info('Closing dataloader mosaic')
if hasattr(self.train_loader.dataset, 'mosaic'):
self.train_loader.dataset.mosaic = False
if hasattr(self.train_loader.dataset, 'close_mosaic'):
@ -44,7 +44,6 @@ class BaseValidator:
Attributes:
dataloader (DataLoader): Dataloader to use for validation.
pbar (tqdm): Progress bar to update during validation.
logger (logging.Logger): Logger to use for validation.
args (SimpleNamespace): Configuration for the validator.
model (nn.Module): Model to validate.
data (dict): Data dictionary.
@ -56,7 +55,7 @@ class BaseValidator:
save_dir (Path): Directory to save results.
"""

def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None):
"""
Initializes a BaseValidator instance.

@ -69,14 +68,13 @@ class BaseValidator:
"""
self.dataloader = dataloader
self.pbar = pbar
self.logger = logger or LOGGER
self.args = args or get_cfg(DEFAULT_CFG)
self.model = None
self.data = None
self.device = None
self.batch_i = None
self.training = True
self.speed = None
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
self.jdict = None

project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
@ -123,14 +121,14 @@ class BaseValidator:
self.device = model.device
if not pt and not jit:
self.args.batch = 1 # export.py models default to batch-size 1
self.logger.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')

if isinstance(self.args.data, str) and self.args.data.endswith('.yaml'):
self.data = check_det_dataset(self.args.data)
elif self.args.task == 'classify':
self.data = check_cls_dataset(self.args.data)
else:
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' not found ❌"))
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))

if self.device.type == 'cpu':
self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
@ -179,7 +177,7 @@ class BaseValidator:
stats = self.get_stats()
self.check_stats(stats)
self.print_results()
self.speed = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt) # speeds per image
self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1E3 for x in dt)))
self.finalize_metrics()
self.run_callbacks('on_val_end')
if self.training:
@ -187,11 +185,11 @@ class BaseValidator:
results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')}
return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats
else:
self.logger.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
self.speed)
LOGGER.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
tuple(self.speed.values()))
if self.args.save_json and self.jdict:
with open(str(self.save_dir / 'predictions.json'), 'w') as f:
self.logger.info(f'Saving {f.name}...')
LOGGER.info(f'Saving {f.name}...')
json.dump(self.jdict, f) # flatten and save
stats = self.eval_json(stats) # update stats
if self.args.plots or self.args.save_json:
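
The BaseValidator hunks above move the per-stage timings from a positional tuple to a dict keyed by 'preprocess', 'inference', 'loss' and 'postprocess', so callers read speed['inference'] instead of speed[1]. A minimal sketch of how such a dict is filled and consumed (the timing values below are made up):

speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}

# after a validation pass: per-image milliseconds for each stage (dummy profiler output)
measured_ms = (1.2, 8.7, 0.0, 2.3)
speed = dict(zip(speed.keys(), measured_ms))

print('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image'
      % tuple(speed.values()))
print(round(speed['inference'], 3))  # the value the logging callbacks now report
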
@ -60,12 +60,12 @@ def run_benchmarks(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',

# Export
if format == '-':
filename = model.ckpt_path
filename = model.ckpt_path or model.cfg
export = model # PyTorch format
else:
filename = model.export(imgsz=imgsz, format=format, half=half, device=device) # all others
export = YOLO(filename)
assert suffix in str(filename), 'export failed'
assert suffix in str(filename), 'export failed'

# Predict
if not (ROOT / 'assets/bus.jpg').exists():
@ -29,7 +29,7 @@ def on_pretrain_routine_start(trainer):
auto_connect_frameworks={'pytorch': False})
task.connect(vars(trainer.args), name='General')
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ ClearML not initialized correctly, not logging this run. {e}')
LOGGER.warning(f'WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}')


def on_train_epoch_end(trainer):
@ -41,9 +41,9 @@ def on_fit_epoch_end(trainer):
task = Task.current_task()
if task and trainer.epoch == 0:
model_info = {
'Parameters': get_num_params(trainer.model),
'GFLOPs': round(get_flops(trainer.model), 3),
'Inference speed (ms/img)': round(trainer.validator.speed[1], 3)}
'model/parameters': get_num_params(trainer.model),
'model/GFLOPs': round(get_flops(trainer.model), 3),
'model/speed(ms)': round(trainer.validator.speed['inference'], 3)}
task.connect(model_info, name='Model')


@ -16,7 +16,7 @@ def on_pretrain_routine_start(trainer):
experiment = comet_ml.Experiment(project_name=trainer.args.project or 'YOLOv8')
experiment.log_parameters(vars(trainer.args))
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ Comet not initialized correctly, not logging this run. {e}')
LOGGER.warning(f'WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}')


def on_train_epoch_end(trainer):
@ -36,7 +36,7 @@ def on_fit_epoch_end(trainer):
model_info = {
'model/parameters': get_num_params(trainer.model),
'model/GFLOPs': round(get_flops(trainer.model), 3),
'model/speed(ms)': round(trainer.validator.speed[1], 3)}
'model/speed(ms)': round(trainer.validator.speed['inference'], 3)}
experiment.log_metrics(model_info, step=trainer.epoch + 1)

@ -2,17 +2,24 @@

from torch.utils.tensorboard import SummaryWriter

from ultralytics.yolo.utils import LOGGER

writer = None # TensorBoard SummaryWriter instance


def _log_scalars(scalars, step=0):
for k, v in scalars.items():
writer.add_scalar(k, v, step)
if writer:
for k, v in scalars.items():
writer.add_scalar(k, v, step)


def on_pretrain_routine_start(trainer):
global writer
writer = SummaryWriter(str(trainer.save_dir))
try:
writer = SummaryWriter(str(trainer.save_dir))
except Exception as e:
writer = None # TensorBoard SummaryWriter instance
LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}')


def on_batch_end(trainer):
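
The TensorBoard callback above now guards both writer creation and scalar logging, so a broken TensorBoard install degrades into a no-op instead of crashing training. A standalone sketch of that guarded-writer pattern, with indentation restored (init_writer and the save path are illustrative names, not the callback API):

from torch.utils.tensorboard import SummaryWriter

writer = None  # TensorBoard SummaryWriter instance


def _log_scalars(scalars, step=0):
    if writer:  # silently skip when TensorBoard failed to initialize
        for k, v in scalars.items():
            writer.add_scalar(k, v, step)


def init_writer(save_dir='runs/example'):
    global writer
    try:
        writer = SummaryWriter(str(save_dir))
    except Exception as e:
        writer = None
        print(f'WARNING: TensorBoard not initialized correctly, not logging this run. {e}')


init_writer()
_log_scalars({'train/loss': 0.5}, step=0)
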
@ -254,7 +254,7 @@ def check_file(file, suffix='', download=True):
return file
else: # search
files = []
for d in 'models', 'datasets', 'tracker/cfg': # search directories
for d in 'models', 'datasets', 'tracker/cfg', 'yolo/cfg': # search directories
files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
if not files:
raise FileNotFoundError(f"'{file}' does not exist")
@ -51,10 +51,9 @@ def generate_ddp_command(world_size, trainer):
file = generate_ddp_file(trainer) # if argv[0].endswith('yolo') else os.path.abspath(argv[0])

# Build command
torch_distributed_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
cmd = [
sys.executable, '-m', torch_distributed_cmd, '--nproc_per_node', f'{world_size}', '--master_port',
f'{find_free_network_port()}', file] + args
dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
port = find_free_network_port()
cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file] + args
return cmd, file

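
For reference, a rough illustration of the command line the refactored generate_ddp_command assembles; every value below is an example placeholder, not output from the real function:

import sys

world_size = 2               # number of GPUs/processes, example value
port = 12345                 # in the real code this comes from find_free_network_port()
file = 'ddp_train_temp.py'   # placeholder for the file generate_ddp_file() writes
args = ['model=yolov8n.pt']  # placeholder trailing arguments

dist_cmd = 'torch.distributed.run'  # 'torch.distributed.launch' on torch < 1.9
cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file] + args
print(' '.join(cmd))
# e.g. /usr/bin/python -m torch.distributed.run --nproc_per_node 2 --master_port 12345 ddp_train_temp.py model=yolov8n.pt
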
@ -12,7 +12,7 @@ import requests
import torch
from tqdm import tqdm

from ultralytics.yolo.utils import LOGGER
from ultralytics.yolo.utils import LOGGER, checks

GITHUB_ASSET_NAMES = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] + \
[f'yolov5{size}u.pt' for size in 'nsmlx'] + \
@ -87,7 +87,7 @@ def safe_download(url,
try:
if curl or i > 0: # curl download with retry, continue
s = 'sS' * (not progress) # silent
r = subprocess.run(['curl', '-#', f'-{s}L', url, '-o', f, '--retry', '9', '-C', '-']).returncode
r = subprocess.run(['curl', '-#', f'-{s}L', url, '-o', f, '--retry', '3', '-C', '-']).returncode
assert r == 0, f'Curl return value {r}'
else: # urllib download
method = 'torch'
@ -112,8 +112,10 @@ def safe_download(url,
break # success
f.unlink() # remove partial downloads
except Exception as e:
if i >= retry:
raise ConnectionError(f'❌ Download failure for {url}') from e
if i == 0 and not checks.check_online():
raise ConnectionError(f'❌ Download failure for {url}. Environment is not online.') from e
elif i >= retry:
raise ConnectionError(f'❌ Download failure for {url}. Retry limit reached.') from e
LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')

if unzip and f.exists() and f.suffix in {'.zip', '.tar', '.gz'}:
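
The safe_download hunk above adds an early connectivity check and distinct error messages for "offline" versus "retry limit reached". A minimal sketch of that retry shape, with stand-in helpers (is_online and download_once are illustrative, not the ultralytics API):

import socket


def is_online(host='8.8.8.8', port=53, timeout=2.0):
    # crude connectivity probe used only for this sketch
    try:
        socket.create_connection((host, port), timeout=timeout).close()
        return True
    except OSError:
        return False


def download_with_retry(url, download_once, retry=3):
    for i in range(retry + 1):
        try:
            return download_once(url)
        except Exception as e:
            if i == 0 and not is_online():
                raise ConnectionError(f'Download failure for {url}. Environment is not online.') from e
            elif i >= retry:
                raise ConnectionError(f'Download failure for {url}. Retry limit reached.') from e
            print(f'Download failure, retrying {i + 1}/{retry} {url}...')
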
@ -7,7 +7,7 @@ from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight
from ultralytics.yolo import v8
from ultralytics.yolo.data import build_classification_dataloader
from ultralytics.yolo.engine.trainer import BaseTrainer
from ultralytics.yolo.utils import DEFAULT_CFG, RANK
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
from ultralytics.yolo.utils.torch_utils import is_parallel, strip_optimizer


@ -64,6 +64,7 @@ class ClassificationTrainer(BaseTrainer):
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
else:
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
ClassificationModel.reshape_outputs(self.model, self.data['nc'])

return # dont return ckpt. Classification doesn't support resume

@ -93,7 +94,7 @@ class ClassificationTrainer(BaseTrainer):

def get_validator(self):
self.loss_names = ['loss']
return v8.classify.ClassificationValidator(self.test_loader, self.save_dir, logger=self.console)
return v8.classify.ClassificationValidator(self.test_loader, self.save_dir)

def criterion(self, preds, batch):
loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / self.args.nbs
@ -132,11 +133,12 @@ class ClassificationTrainer(BaseTrainer):
strip_optimizer(f) # strip optimizers
# TODO: validate best.pt after training completes
# if f is self.best:
# self.console.info(f'\nValidating {f}...')
# LOGGER.info(f'\nValidating {f}...')
# self.validator.args.save_json = True
# self.metrics = self.validator(model=f)
# self.metrics.pop('fitness', None)
# self.run_callbacks('on_fit_epoch_end')
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")


def train(cfg=DEFAULT_CFG, use_python=False):
@ -2,14 +2,14 @@

from ultralytics.yolo.data import build_classification_dataloader
from ultralytics.yolo.engine.validator import BaseValidator
from ultralytics.yolo.utils import DEFAULT_CFG
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER
from ultralytics.yolo.utils.metrics import ClassifyMetrics


class ClassificationValidator(BaseValidator):

def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
super().__init__(dataloader, save_dir, pbar, logger, args)
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None):
super().__init__(dataloader, save_dir, pbar, args)
self.args.task = 'classify'
self.metrics = ClassifyMetrics()

@ -31,7 +31,7 @@ class ClassificationValidator(BaseValidator):
self.targets.append(batch['cls'])

def finalize_metrics(self, *args, **kwargs):
self.metrics.speed = dict(zip(self.metrics.speed.keys(), self.speed))
self.metrics.speed = self.speed

def get_stats(self):
self.metrics.process(self.targets, self.pred)
@ -45,7 +45,7 @@ class ClassificationValidator(BaseValidator):

def print_results(self):
pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format
self.logger.info(pf % ('all', self.metrics.top1, self.metrics.top5))
LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5))


def val(cfg=DEFAULT_CFG, use_python=False):
@ -66,10 +66,7 @@ class DetectionTrainer(BaseTrainer):

def get_validator(self):
self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
return v8.detect.DetectionValidator(self.test_loader,
save_dir=self.save_dir,
logger=self.console,
args=copy(self.args))
return v8.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

def criterion(self, preds, batch):
if not hasattr(self, 'compute_loss'):
@ -9,7 +9,7 @@ import torch
from ultralytics.yolo.data import build_dataloader
from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
from ultralytics.yolo.engine.validator import BaseValidator
from ultralytics.yolo.utils import DEFAULT_CFG, colorstr, ops
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, colorstr, ops
from ultralytics.yolo.utils.checks import check_requirements
from ultralytics.yolo.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
from ultralytics.yolo.utils.plotting import output_to_target, plot_images
@ -18,8 +18,8 @@ from ultralytics.yolo.utils.torch_utils import de_parallel

class DetectionValidator(BaseValidator):

def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
super().__init__(dataloader, save_dir, pbar, logger, args)
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None):
super().__init__(dataloader, save_dir, pbar, args)
self.args.task = 'detect'
self.is_coco = False
self.class_map = None
@ -112,7 +112,7 @@ class DetectionValidator(BaseValidator):
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')

def finalize_metrics(self, *args, **kwargs):
self.metrics.speed = dict(zip(self.metrics.speed.keys(), self.speed))
self.metrics.speed = self.speed

def get_stats(self):
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
@ -123,15 +123,15 @@ class DetectionValidator(BaseValidator):

def print_results(self):
pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys) # print format
self.logger.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
LOGGER.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
if self.nt_per_class.sum() == 0:
self.logger.warning(
LOGGER.warning(
f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels')

# Print results per class
if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
for i, c in enumerate(self.metrics.ap_class_index):
self.logger.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
LOGGER.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))

if self.args.plots:
self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values()))
@ -212,7 +212,7 @@ class DetectionValidator(BaseValidator):
if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
pred_json = self.save_dir / 'predictions.json' # predictions
self.logger.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
check_requirements('pycocotools>=2.0.6')
from pycocotools.coco import COCO # noqa
@ -230,7 +230,7 @@ class DetectionValidator(BaseValidator):
eval.summarize()
stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2] # update mAP50-95 and mAP50
except Exception as e:
self.logger.warning(f'pycocotools unable to run: {e}')
LOGGER.warning(f'pycocotools unable to run: {e}')
return stats

@ -68,11 +68,10 @@ class SegmentationPredictor(DetectionPredictor):
log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "

# Mask plotting
self.annotator.masks(
mask.masks,
colors=[colors(x, True) for x in det.cls],
im_gpu=torch.as_tensor(im0, dtype=torch.float16).to(self.device).permute(2, 0, 1).flip(0).contiguous() /
255 if self.args.retina_masks else im[idx])
if self.args.save or self.args.show:
im_gpu = torch.as_tensor(im0, dtype=torch.float16, device=mask.masks.device).permute(
2, 0, 1).flip(0).contiguous() / 255 if self.args.retina_masks else im[idx]
self.annotator.masks(masks=mask.masks, colors=[colors(x, True) for x in det.cls], im_gpu=im_gpu)

# Write results
for j, d in enumerate(reversed(det)):
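
The predictor hunk above now builds the overlay image only when saving or showing results, and constructs the tensor directly on the mask tensor's device. A standalone sketch of just the image-tensor preparation (shapes, values and device below are dummies):

import numpy as np
import torch

im0 = np.zeros((480, 640, 3), dtype=np.uint8)  # HWC BGR frame, placeholder data
device = torch.device('cpu')                   # would be mask.masks.device in the diff
# HWC BGR uint8 -> CHW RGB float16 in [0, 1]
im_gpu = torch.as_tensor(im0, dtype=torch.float16, device=device).permute(2, 0, 1).flip(0).contiguous() / 255
print(im_gpu.shape, im_gpu.dtype)  # torch.Size([3, 480, 640]) torch.float16
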
@ -32,10 +32,7 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):

def get_validator(self):
self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss'
return v8.segment.SegmentationValidator(self.test_loader,
save_dir=self.save_dir,
logger=self.console,
args=copy(self.args))
return v8.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

def criterion(self, preds, batch):
if not hasattr(self, 'compute_loss'):
@ -86,10 +83,6 @@ class SegLoss(Loss):
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)

masks = batch['masks'].to(self.device).float()
if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]

# pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)

@ -103,10 +96,15 @@ class SegLoss(Loss):
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE

# bbox loss
if fg_mask.sum():
# bbox loss
loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor,
target_scores, target_scores_sum, fg_mask)
# masks loss
masks = batch['masks'].to(self.device).float()
if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]

for i in range(batch_size):
if fg_mask[i].sum():
mask_idx = target_gt_idx[i][fg_mask[i]]
@ -121,9 +119,9 @@ class SegLoss(Loss):
marea) # seg loss
# WARNING: Uncomment lines below in case of Multi-GPU DDP unused gradient errors
# else:
# loss[1] += proto.sum() * 0
# loss[1] += proto.sum() * 0 + pred_masks.sum() * 0
# else:
# loss[1] += proto.sum() * 0
# loss[1] += proto.sum() * 0 + pred_masks.sum() * 0

loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.box / batch_size # seg gain
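
The SegLoss hunk above moves ground-truth mask downsampling inside the foreground branch, so the interpolation only runs when there are assigned boxes. A minimal sketch of the downsampling step itself (tensor sizes below are illustrative, not taken from the loss):

import torch
import torch.nn.functional as F

masks = torch.randint(0, 2, (8, 640, 640)).float()  # (num_objects, H, W) dummy GT masks
mask_h, mask_w = 160, 160                            # prototype mask resolution, example value
if tuple(masks.shape[-2:]) != (mask_h, mask_w):      # downsample with nearest-neighbour
    masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
print(masks.shape)  # torch.Size([8, 160, 160])
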
@ -7,7 +7,7 @@ import numpy as np
import torch
import torch.nn.functional as F

from ultralytics.yolo.utils import DEFAULT_CFG, NUM_THREADS, ops
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, NUM_THREADS, ops
from ultralytics.yolo.utils.checks import check_requirements
from ultralytics.yolo.utils.metrics import SegmentMetrics, box_iou, mask_iou
from ultralytics.yolo.utils.plotting import output_to_target, plot_images
@ -16,8 +16,8 @@ from ultralytics.yolo.v8.detect import DetectionValidator

class SegmentationValidator(DetectionValidator):

def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
super().__init__(dataloader, save_dir, pbar, logger, args)
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None):
super().__init__(dataloader, save_dir, pbar, args)
self.args.task = 'segment'
self.metrics = SegmentMetrics(save_dir=self.save_dir)

@ -120,7 +120,7 @@ class SegmentationValidator(DetectionValidator):
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')

def finalize_metrics(self, *args, **kwargs):
self.metrics.speed = dict(zip(self.metrics.speed.keys(), self.speed))
self.metrics.speed = self.speed

def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False):
"""
@ -207,7 +207,7 @@ class SegmentationValidator(DetectionValidator):
if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
pred_json = self.save_dir / 'predictions.json' # predictions
self.logger.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
check_requirements('pycocotools>=2.0.6')
from pycocotools.coco import COCO # noqa
@ -228,7 +228,7 @@ class SegmentationValidator(DetectionValidator):
stats[self.metrics.keys[idx + 1]], stats[
self.metrics.keys[idx]] = eval.stats[:2] # update mAP50-95 and mAP50
except Exception as e:
self.logger.warning(f'pycocotools unable to run: {e}')
LOGGER.warning(f'pycocotools unable to run: {e}')
return stats