Mirror of https://github.com/THU-MIG/yolov10.git, synced 2025-05-23 21:44:22 +08:00
ultralytics 8.0.44 export and task fixes (#1088)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>

Parent: fe61018975
Commit: 3ea659411b
Changed file: .github/workflows/ci.yaml (vendored), 7 changes
@@ -32,9 +32,14 @@ jobs:
         # key: ${{ runner.os }}-Benchmarks-${{ hashFiles('requirements.txt') }}
         # restore-keys: ${{ runner.os }}-Benchmarks-
       - name: Install requirements
+        shell: bash  # for Windows compatibility
         run: |
           python -m pip install --upgrade pip wheel
+          if [ "${{ matrix.os }}" == "macos-latest" ]; then
+            pip install -e . coremltools openvino-dev tensorflow-macos --extra-index-url https://download.pytorch.org/whl/cpu
+          else
           pip install -e . coremltools openvino-dev tensorflow-cpu paddlepaddle x2paddle --extra-index-url https://download.pytorch.org/whl/cpu
+          fi
           yolo export format=tflite
       - name: Check environment
         run: |
@@ -94,6 +99,7 @@ jobs:
           key: ${{ runner.os }}-${{ matrix.python-version }}-pip-${{ hashFiles('requirements.txt') }}
           restore-keys: ${{ runner.os }}-${{ matrix.python-version }}-pip-
       - name: Install requirements
+        shell: bash  # for Windows compatibility
         run: |
           python -m pip install --upgrade pip wheel
           if [ "${{ matrix.torch }}" == "1.8.0" ]; then
@@ -101,7 +107,6 @@ jobs:
           else
             pip install -e '.[export]' pytest --extra-index-url https://download.pytorch.org/whl/cpu
           fi
-        shell: bash  # for Windows compatibility
       - name: Check environment
         run: |
           echo "RUNNER_OS is ${{ runner.os }}"
@@ -78,13 +78,6 @@
    }
   ]
  },
- {
-  "cell_type": "markdown",
-  "source": [],
-  "metadata": {
-   "id": "ZOwTlorPd8-D"
-  }
- },
 {
  "cell_type": "markdown",
  "metadata": {
@@ -3,7 +3,7 @@
 import subprocess
 from pathlib import Path

-from ultralytics.yolo.utils import LINUX, ROOT, SETTINGS
+from ultralytics.yolo.utils import LINUX, ROOT, SETTINGS, checks

 MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n'
 CFG = 'yolov8n'
@@ -49,6 +49,7 @@ def test_val_classify():
 # Predict checks -------------------------------------------------------------------------------------------------------
 def test_predict_detect():
     run(f"yolo predict model={MODEL}.pt source={ROOT / 'assets'} imgsz=32")
+    if checks.check_online():
         run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32')
         run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_landscape_min.mov imgsz=32')
         run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_portrait_min.mov imgsz=32')
@@ -9,7 +9,7 @@ from PIL import Image

 from ultralytics import YOLO
 from ultralytics.yolo.data.build import load_inference_source
-from ultralytics.yolo.utils import LINUX, ROOT, SETTINGS
+from ultralytics.yolo.utils import LINUX, ROOT, SETTINGS, checks

 MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt'
 CFG = 'yolov8n.yaml'
@@ -49,28 +49,20 @@ def test_predict_dir():

 def test_predict_img():
     model = YOLO(MODEL)
-    output = model(source=Image.open(SOURCE), save=True, verbose=True)  # PIL
-    assert len(output) == 1, 'predict test failed'
-    img = cv2.imread(str(SOURCE))
-    output = model(source=img, save=True, save_txt=True)  # ndarray
-    assert len(output) == 1, 'predict test failed'
-    output = model(source=[img, img], save=True, save_txt=True)  # batch
-    assert len(output) == 2, 'predict test failed'
-    output = model(source=[img, img], save=True, stream=True)  # stream
-    assert len(list(output)) == 2, 'predict test failed'
-    tens = torch.zeros(320, 640, 3)
-    output = model(tens.numpy())
-    assert len(output) == 1, 'predict test failed'
-    # test multiple source
-    imgs = [
-        SOURCE,  # filename
+    im = cv2.imread(str(SOURCE))
+    assert len(model(source=Image.open(SOURCE), save=True, verbose=True)) == 1  # PIL
+    assert len(model(source=im, save=True, save_txt=True)) == 1  # ndarray
+    assert len(model(source=[im, im], save=True, save_txt=True)) == 2  # batch
+    assert len(list(model(source=[im, im], save=True, stream=True))) == 2  # stream
+    assert len(model(torch.zeros(320, 640, 3).numpy())) == 1  # tensor to numpy
+    batch = [
+        str(SOURCE),  # filename
         Path(SOURCE),  # Path
-        'https://ultralytics.com/images/zidane.jpg',  # URI
+        'https://ultralytics.com/images/zidane.jpg' if checks.check_online() else SOURCE,  # URI
         cv2.imread(str(SOURCE)),  # OpenCV
         Image.open(SOURCE),  # PIL
         np.zeros((320, 640, 3))]  # numpy
-    output = model(imgs)
-    assert len(output) == 6, 'predict test failed!'
+    assert len(model(batch)) == len(batch)  # multiple sources in a batch


 def test_predict_grey_and_4ch():
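The rewritten test exercises every source type the predictor accepts in one batch. A minimal standalone sketch of the same idea, assuming a local image of your own in place of the test asset:

    from pathlib import Path

    import cv2
    import numpy as np
    from PIL import Image

    from ultralytics import YOLO

    model = YOLO('yolov8n.pt')
    src = 'bus.jpg'  # hypothetical local image; substitute any image you have
    batch = [
        src,                                          # filename string
        Path(src),                                    # pathlib.Path
        'https://ultralytics.com/images/zidane.jpg',  # URI (needs internet access)
        cv2.imread(src),                              # OpenCV BGR ndarray
        Image.open(src),                              # PIL image
        np.zeros((320, 640, 3))]                      # raw numpy array
    results = model(batch)  # one Results object per source
    assert len(results) == len(batch)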
@@ -85,6 +77,11 @@ def test_val():
     model.val(data='coco8.yaml', imgsz=32)


+def test_val_scratch():
+    model = YOLO(CFG)
+    model.val(data='coco8.yaml', imgsz=32)
+
+
 def test_train_scratch():
     model = YOLO(CFG)
     model.train(data='coco8.yaml', epochs=1, imgsz=32)
@@ -103,6 +100,12 @@ def test_export_torchscript():
     YOLO(f)(SOURCE)  # exported model inference


+def test_export_torchscript_scratch():
+    model = YOLO(CFG)
+    f = model.export(format='torchscript')
+    YOLO(f)(SOURCE)  # exported model inference
+
+
 def test_export_onnx():
     model = YOLO(MODEL)
     f = model.export(format='onnx')
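The new scratch tests check that validation and export also work for a model built from a bare YAML config, which is what the exporter and model changes later in this commit enable. A minimal sketch of that flow, assuming the yolov8n.yaml config bundled with the package:

    from ultralytics import YOLO

    model = YOLO('yolov8n.yaml')                 # build an untrained model from a config
    model.val(data='coco8.yaml', imgsz=32)       # validate it on the tiny coco8 dataset
    f = model.export(format='torchscript')       # export now works without prior training
    YOLO(f)('https://ultralytics.com/images/bus.jpg')  # run inference with the exported file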
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license

-__version__ = '8.0.43'
+__version__ = '8.0.44'

 from ultralytics.yolo.engine.model import YOLO
 from ultralytics.yolo.utils.checks import check_yolo as checks
@@ -15,7 +15,7 @@ EXPORT_FORMATS_HUB = EXPORT_FORMATS_LIST + ['ultralytics_tflite', 'ultralytics_c

 def start(key=''):
     """
-    Start training models with Ultralytics HUB. Usage: from src.ultralytics import start; start('API_KEY')
+    Start training models with Ultralytics HUB. Usage: from ultralytics.hub import start; start('API_KEY')
     """
     auth = Auth(key)
     try:
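The docstring fix matches the installed package layout; the corrected usage line from the diff, spelled out:

    from ultralytics.hub import start

    start('API_KEY')  # 'API_KEY' is a placeholder for a real Ultralytics HUB key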
@@ -30,9 +30,9 @@ def start(key=''):
         session = HubTrainingSession(model_id=model_id, auth=auth)
         session.check_disk_space()

-        trainer = YOLO(session.input_file)
-        session.register_callbacks(trainer)
-        trainer.train(**session.train_args)
+        model = YOLO(session.input_file)
+        session.register_callbacks(model)
+        model.train(**session.train_args)
     except Exception as e:
         LOGGER.warning(f'{PREFIX}{e}')

@@ -93,6 +93,5 @@ def get_export(key='', format='torchscript'):
     return r.json()


-# temp. For checking
 if __name__ == '__main__':
     start()
@@ -26,6 +26,7 @@ class HubTrainingSession:
         self._timers = {}  # rate limit timers (seconds)
         self._metrics_queue = {}  # metrics queue
         self.model = self._get_model()
+        self.alive = True
         self._start_heartbeat()  # start heartbeats
         self._register_signal_handlers()

@@ -52,37 +53,6 @@ class HubTrainingSession:
         payload = {'metrics': self._metrics_queue.copy(), 'type': 'metrics'}
         smart_request(f'{self.api_url}', json=payload, headers=self.auth_header, code=2)

-    def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
-        # Upload a model to HUB
-        file = None
-        if Path(weights).is_file():
-            with open(weights, 'rb') as f:
-                file = f.read()
-        if final:
-            smart_request(
-                f'{self.api_url}/upload',
-                data={
-                    'epoch': epoch,
-                    'type': 'final',
-                    'map': map},
-                files={'best.pt': file},
-                headers=self.auth_header,
-                retry=10,
-                timeout=3600,
-                code=4,
-            )
-        else:
-            smart_request(
-                f'{self.api_url}/upload',
-                data={
-                    'epoch': epoch,
-                    'type': 'epoch',
-                    'isBest': bool(is_best)},
-                headers=self.auth_header,
-                files={'last.pt': file},
-                code=3,
-            )
-
     def _get_model(self):
         # Returns model from database by id
         api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'
@@ -151,7 +121,7 @@ class HubTrainingSession:
             model_info = {
                 'model/parameters': get_num_params(trainer.model),
                 'model/GFLOPs': round(get_flops(trainer.model), 3),
-                'model/speed(ms)': round(trainer.validator.speed[1], 3)}
+                'model/speed(ms)': round(trainer.validator.speed['inference'], 3)}
             all_plots = {**all_plots, **model_info}
         self._metrics_queue[trainer.epoch] = json.dumps(all_plots)
         if time() - self._timers['metrics'] > self._rate_limits['metrics']:
@@ -169,52 +139,45 @@ class HubTrainingSession:

     def on_train_end(self, trainer):
         # Upload final model and metrics with exponential standoff
-        LOGGER.info(f'{PREFIX}Training completed successfully ✅')
-        LOGGER.info(f'{PREFIX}Uploading final {self.model_id}')
+        LOGGER.info(f'{PREFIX}Training completed successfully ✅\n'
+                    f'{PREFIX}Uploading final {self.model_id}')

-        # hack for fetching mAP
-        mAP = trainer.metrics.get('metrics/mAP50-95(B)', 0)
-        self._upload_model(trainer.epoch, trainer.best, map=mAP, final=True)  # results[3] is mAP0.5:0.95
+        self._upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True)
         self.alive = False  # stop heartbeats
         LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀')

     def _upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
         # Upload a model to HUB
-        file = None
         if Path(weights).is_file():
             with open(weights, 'rb') as f:
                 file = f.read()
-        file_param = {'best.pt' if final else 'last.pt': file}
-        endpoint = f'{self.api_url}/upload'
+        else:
+            LOGGER.warning(f'{PREFIX}WARNING ⚠️ Model upload failed. Missing model {weights}.')
+            file = None
         data = {'epoch': epoch}
         if final:
             data.update({'type': 'final', 'map': map})
         else:
             data.update({'type': 'epoch', 'isBest': bool(is_best)})

-        smart_request(
-            endpoint,
-            data=data,
-            files=file_param,
-            headers=self.auth_header,
-            retry=10 if final else None,
-            timeout=3600 if final else None,
-            code=4 if final else 3,
-        )
+        smart_request(f'{self.api_url}/upload',
+                      data=data,
+                      files={'best.pt' if final else 'last.pt': file},
+                      headers=self.auth_header,
+                      retry=10 if final else None,
+                      timeout=3600 if final else None,
+                      code=4 if final else 3)

     @threaded
     def _start_heartbeat(self):
-        self.alive = True
         while self.alive:
-            r = smart_request(
-                f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
-                json={
-                    'agent': AGENT_NAME,
-                    'agentId': self.agent_id},
-                headers=self.auth_header,
-                retry=0,
-                code=5,
-                thread=False,
-            )
+            r = smart_request(f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
+                              json={
+                                  'agent': AGENT_NAME,
+                                  'agentId': self.agent_id},
+                              headers=self.auth_header,
+                              retry=0,
+                              code=5,
+                              thread=False)
            self.agent_id = r.json().get('data', {}).get('agentId', None)
            sleep(self._rate_limits['heartbeat'])
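With the duplicated upload_model helper removed, _upload_model becomes the single code path and the per-epoch versus final differences collapse into the `final` flag. A self-contained sketch of the payload logic it now encodes (build_upload_request is a hypothetical name used only for illustration; smart_request itself is not reproduced here):

    from pathlib import Path

    def build_upload_request(api_url, epoch, weights, is_best=False, map=0.0, final=False):
        # One code path for both per-epoch and final uploads, selected by `final`.
        weights = Path(weights)
        file = weights.read_bytes() if weights.is_file() else None  # the real code warns when the file is missing
        data = {'epoch': epoch}
        data.update({'type': 'final', 'map': map} if final else {'type': 'epoch', 'isBest': bool(is_best)})
        return {'url': f'{api_url}/upload',
                'data': data,
                'files': {'best.pt' if final else 'last.pt': file},
                'retry': 10 if final else None,
                'timeout': 3600 if final else None,
                'code': 4 if final else 3}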
@@ -181,7 +181,6 @@ class AutoBackend(nn.Module):
             import tensorflow as tf
             keras = False  # assume TF1 saved_model
             model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
-            w = Path(w) / 'metadata.yaml'
         elif pb:  # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
             LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
             import tensorflow as tf
@ -258,8 +257,9 @@ class AutoBackend(nn.Module):
|
|||||||
f'\n\n{EXPORT_FORMATS_TABLE}')
|
f'\n\n{EXPORT_FORMATS_TABLE}')
|
||||||
|
|
||||||
# Load external metadata YAML
|
# Load external metadata YAML
|
||||||
|
w = Path(w)
|
||||||
if xml or saved_model or paddle:
|
if xml or saved_model or paddle:
|
||||||
metadata = Path(w).parent / 'metadata.yaml'
|
metadata = (w if saved_model else w.parents[1] if paddle else w.parent) / 'metadata.yaml'
|
||||||
if metadata.exists():
|
if metadata.exists():
|
||||||
metadata = yaml_load(metadata)
|
metadata = yaml_load(metadata)
|
||||||
stride, names = int(metadata['stride']), metadata['names'] # load metadata
|
stride, names = int(metadata['stride']), metadata['names'] # load metadata
|
||||||
|
@ -287,6 +287,7 @@ class ClassificationModel(BaseModel):
|
|||||||
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
|
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
|
||||||
self.yaml['nc'] = nc # override yaml value
|
self.yaml['nc'] = nc # override yaml value
|
||||||
self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
|
self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
|
||||||
|
self.stride = torch.Tensor([1]) # no stride constraints
|
||||||
self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
|
self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
|
||||||
self.info()
|
self.info()
|
||||||
|
|
||||||
@@ -520,14 +521,15 @@ def guess_model_task(model):

     # Guess from model filename
     if isinstance(model, (str, Path)):
-        model = Path(model).stem
-        if '-seg' in model:
+        model = Path(model)
+        if '-seg' in model.stem or 'segment' in model.parts:
             return 'segment'
-        elif '-cls' in model:
+        elif '-cls' in model.stem or 'classify' in model.parts:
             return 'classify'
-        else:
+        elif 'detect' in model.parts:
             return 'detect'

     # Unable to determine task from model
-    raise SyntaxError('YOLO is unable to automatically guess model task. Explicitly define task for your model, '
-                      "i.e. 'task=detect', 'task=segment' or 'task=classify'.")
+    LOGGER.warning("WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
+                   "Explicitly define task for your model, i.e. 'task=detect', 'task=segment' or 'task=classify'.")
+    return 'detect'  # assume detect
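Task guessing now also looks at the directory parts of the path, not just the filename stem, and an unrecognised model falls back to 'detect' with a warning instead of raising. A standalone illustration of the new filename branch (the example paths are hypothetical):

    from pathlib import Path

    def guess_task_from_path(p):
        # Same matching rules as the updated filename branch of guess_model_task().
        p = Path(p)
        if '-seg' in p.stem or 'segment' in p.parts:
            return 'segment'
        if '-cls' in p.stem or 'classify' in p.parts:
            return 'classify'
        if 'detect' in p.parts:
            return 'detect'
        return 'detect'  # fallback; the real code logs a warning first

    print(guess_task_from_path('runs/segment/train/weights/best.pt'))  # segment
    print(guess_task_from_path('yolov8n-cls.pt'))                      # classify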
@@ -47,11 +47,12 @@ CLI_HELP_MSG = \
     GitHub: https://github.com/ultralytics/ultralytics
     """

-CFG_FLOAT_KEYS = {'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear'}
+# Define keys for arg type checks
+CFG_FLOAT_KEYS = {'warmup_epochs', 'box', 'cls', 'dfl', 'degrees', 'shear', 'fl_gamma'}
 CFG_FRACTION_KEYS = {
-    'dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr', 'fl_gamma',
-    'label_smoothing', 'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud', 'fliplr', 'mosaic',
-    'mixup', 'copy_paste', 'conf', 'iou'}
+    'dropout', 'iou', 'lr0', 'lrf', 'momentum', 'weight_decay', 'warmup_momentum', 'warmup_bias_lr', 'label_smoothing',
+    'hsv_h', 'hsv_s', 'hsv_v', 'translate', 'scale', 'perspective', 'flipud', 'fliplr', 'mosaic', 'mixup', 'copy_paste',
+    'conf', 'iou'}  # fractional floats limited to 0.0 - 1.0
 CFG_INT_KEYS = {
     'epochs', 'patience', 'batch', 'workers', 'seed', 'close_mosaic', 'mask_ratio', 'max_det', 'vid_stride',
     'line_thickness', 'workspace', 'nbs', 'save_period'}
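These key sets drive type and range checks on CLI overrides; the checking code itself sits outside this hunk. A hypothetical sketch of how a fraction key could be validated against the 0.0 to 1.0 limit the new comment documents (validate_fraction is not a real function in the package):

    CFG_FRACTION_KEYS = {'dropout', 'iou', 'lr0', 'lrf', 'momentum', 'conf'}  # abbreviated for the example

    def validate_fraction(key, value):
        # Hypothetical check: fraction keys must be floats limited to 0.0 - 1.0.
        if key in CFG_FRACTION_KEYS:
            value = float(value)
            if not 0.0 <= value <= 1.0:
                raise ValueError(f"'{key}={value}' is invalid, valid values are between 0.0 and 1.0")
        return value

    print(validate_fraction('conf', '0.25'))  # 0.25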
@@ -224,7 +225,7 @@ def entrypoint(debug=''):
             assert v, f"missing '{k}' value"
             if k == 'cfg':  # custom.yaml passed
                 LOGGER.info(f'Overriding {DEFAULT_CFG_PATH} with {v}')
-                overrides = {k: val for k, val in yaml_load(v).items() if k != 'cfg'}
+                overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != 'cfg'}
             else:
                 if v.lower() == 'none':
                     v = None
@@ -255,7 +256,6 @@ def entrypoint(debug=''):
             check_cfg_mismatch(full_args_dict, {a: ''})

     # Defaults
-    task2model = dict(detect='yolov8n.pt', segment='yolov8n-seg.pt', classify='yolov8n-cls.pt')
     task2data = dict(detect='coco128.yaml', segment='coco128-seg.yaml', classify='imagenet100')

     # Mode
@@ -272,27 +272,28 @@ def entrypoint(debug=''):

     # Model
     model = overrides.pop('model', DEFAULT_CFG.model)
-    task = overrides.pop('task', None)
     if model is None:
-        model = task2model.get(task, 'yolov8n.pt')
+        model = 'yolov8n.pt'
         LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
     from ultralytics.yolo.engine.model import YOLO
     overrides['model'] = model
     model = YOLO(model)

     # Task
-    if task and task != model.task:
-        LOGGER.warning(f"WARNING ⚠️ 'task={task}' conflicts with {model.task} model {overrides['model']}. "
-                       f"Inheriting 'task={model.task}' from {overrides['model']} and ignoring 'task={task}'.")
-    task = model.task
-    overrides['task'] = task
+    # if task and task != model.task:
+    #     LOGGER.warning(f"WARNING ⚠️ 'task={task}' conflicts with {model.task} model {overrides['model']}. "
+    #                    f"Inheriting 'task={model.task}' from {overrides['model']} and ignoring 'task={task}'.")
+    overrides['task'] = overrides.get('task', model.task)
+    model.task = overrides['task']
+
+    # Mode
     if mode in {'predict', 'track'} and 'source' not in overrides:
         overrides['source'] = DEFAULT_CFG.source or ROOT / 'assets' if (ROOT / 'assets').exists() \
             else 'https://ultralytics.com/images/bus.jpg'
         LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
     elif mode in ('train', 'val'):
         if 'data' not in overrides:
-            overrides['data'] = task2data.get(task, DEFAULT_CFG.data)
+            overrides['data'] = task2data.get(overrides['task'], DEFAULT_CFG.data)
             LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using {model.task} default 'data={overrides['data']}'.")
     elif mode == 'export':
         if 'format' not in overrides:
@@ -6,7 +6,6 @@ Dataloaders and dataset utils
 import contextlib
 import glob
 import hashlib
-import json
 import math
 import os
 import random
@@ -27,11 +26,9 @@ from PIL import ExifTags, Image, ImageOps
 from torch.utils.data import DataLoader, Dataset, dataloader, distributed
 from tqdm import tqdm

-from ultralytics.yolo.data.utils import check_det_dataset
 from ultralytics.yolo.utils import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, is_colab, is_dir_writeable,
-                                    is_kaggle, yaml_load)
-from ultralytics.yolo.utils.checks import check_requirements, check_yaml
-from ultralytics.yolo.utils.downloads import unzip_file
+                                    is_kaggle)
+from ultralytics.yolo.utils.checks import check_requirements
 from ultralytics.yolo.utils.ops import clean_str, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn
 from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first

@@ -1037,127 +1034,6 @@ def verify_image_label(args):
         return [None, None, None, None, nm, nf, ne, nc, msg]


-class HUBDatasetStats():
-    """ Class for generating HUB dataset JSON and `-hub` dataset directory
-
-    Arguments
-        path: Path to data.yaml or data.zip (with data.yaml inside data.zip)
-        autodownload: Attempt to download dataset if not found locally
-
-    Usage
-        from ultralytics.yolo.data.dataloaders.v5loader import HUBDatasetStats
-        stats = HUBDatasetStats('coco128.yaml', autodownload=True)  # usage 1
-        stats = HUBDatasetStats('path/to/coco128.zip')  # usage 2
-        stats.get_json(save=False)
-        stats.process_images()
-    """
-
-    def __init__(self, path='coco128.yaml', autodownload=False):
-        # Initialize class
-        zipped, data_dir, yaml_path = self._unzip(Path(path))
-        # try:
-        #     data = yaml_load(check_yaml(yaml_path))  # data dict
-        #     if zipped:
-        #         data['path'] = data_dir
-        # except Exception as e:
-        #     raise Exception('error/HUB/dataset_stats/yaml_load') from e
-
-        data = check_det_dataset(yaml_path, autodownload)  # download dataset if missing
-        self.hub_dir = Path(str(data['path']) + '-hub')
-        self.im_dir = self.hub_dir / 'images'
-        self.im_dir.mkdir(parents=True, exist_ok=True)  # makes /images
-        self.stats = {'nc': data['nc'], 'names': list(data['names'].values())}  # statistics dictionary
-        self.data = data
-
-    @staticmethod
-    def _find_yaml(dir):
-        # Return data.yaml file
-        files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml'))  # try root level first and then recursive
-        assert files, f'No *.yaml file found in {dir}'
-        if len(files) > 1:
-            files = [f for f in files if f.stem == dir.stem]  # prefer *.yaml files that match dir name
-            assert files, f'Multiple *.yaml files found in {dir}, only 1 *.yaml file allowed'
-        assert len(files) == 1, f'Multiple *.yaml files found: {files}, only 1 *.yaml file allowed in {dir}'
-        return files[0]
-
-    def _unzip(self, path):
-        # Unzip data.zip
-        if not str(path).endswith('.zip'):  # path is data.yaml
-            return False, None, path
-        assert Path(path).is_file(), f'Error unzipping {path}, file not found'
-        unzip_file(path, path=path.parent)
-        dir = path.with_suffix('')  # dataset directory == zip name
-        assert dir.is_dir(), f'Error unzipping {path}, {dir} not found. path/to/abc.zip MUST unzip to path/to/abc/'
-        return True, str(dir), self._find_yaml(dir)  # zipped, data_dir, yaml_path
-
-    def _hub_ops(self, f, max_dim=1920):
-        # HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing
-        f_new = self.im_dir / Path(f).name  # dataset-hub image filename
-        try:  # use PIL
-            im = Image.open(f)
-            r = max_dim / max(im.height, im.width)  # ratio
-            if r < 1.0:  # image too large
-                im = im.resize((int(im.width * r), int(im.height * r)))
-            im.save(f_new, 'JPEG', quality=50, optimize=True)  # save
-        except Exception as e:  # use OpenCV
-            LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}')
-            im = cv2.imread(f)
-            im_height, im_width = im.shape[:2]
-            r = max_dim / max(im_height, im_width)  # ratio
-            if r < 1.0:  # image too large
-                im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
-            cv2.imwrite(str(f_new), im)
-
-    def get_json(self, save=False, verbose=False):
-        # Return dataset JSON for Ultralytics HUB
-        def _round(labels):
-            # Update labels to integer class and 6 decimal place floats
-            return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels]
-
-        for split in 'train', 'val', 'test':
-            if self.data.get(split) is None:
-                self.stats[split] = None  # i.e. no test set
-                continue
-            dataset = LoadImagesAndLabels(self.data[split])  # load dataset
-            x = np.array([
-                np.bincount(label[:, 0].astype(int), minlength=self.data['nc'])
-                for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics')])  # shape(128x80)
-            self.stats[split] = {
-                'instance_stats': {
-                    'total': int(x.sum()),
-                    'per_class': x.sum(0).tolist()},
-                'image_stats': {
-                    'total': dataset.n,
-                    'unlabelled': int(np.all(x == 0, 1).sum()),
-                    'per_class': (x > 0).sum(0).tolist()},
-                'labels': [{
-                    str(Path(k).name): _round(v.tolist())} for k, v in zip(dataset.im_files, dataset.labels)]}
-
-        # Save, print and return
-        if save:
-            stats_path = self.hub_dir / 'stats.json'
-            LOGGER.info(f'Saving {stats_path.resolve()}...')
-            with open(stats_path, 'w') as f:
-                json.dump(self.stats, f)  # save stats.json
-        if verbose:
-            LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
-        return self.stats
-
-    def process_images(self):
-        # Compress images for Ultralytics HUB
-        for split in 'train', 'val', 'test':
-            if self.data.get(split) is None:
-                continue
-            dataset = LoadImagesAndLabels(self.data[split])  # load dataset
-            desc = f'{split} images'
-            total = dataset.n
-            with ThreadPool(NUM_THREADS) as pool:
-                for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=total, desc=desc):
-                    pass
-        LOGGER.info(f'Done. All images saved to {self.im_dir}')
-        return self.im_dir
-
-
 # Classification dataloaders -------------------------------------------------------------------------------------------
 class ClassificationDataset(torchvision.datasets.ImageFolder):
     """
@@ -2,9 +2,11 @@

 import contextlib
 import hashlib
+import json
 import os
 import subprocess
 import time
+from multiprocessing.pool import ThreadPool
 from pathlib import Path
 from tarfile import is_tarfile
 from zipfile import is_zipfile
@@ -12,10 +14,11 @@ from zipfile import is_zipfile
 import cv2
 import numpy as np
 from PIL import ExifTags, Image, ImageOps
+from tqdm import tqdm

-from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, ROOT, colorstr, emojis, yaml_load
+from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, colorstr, emojis, yaml_load
 from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii
-from ultralytics.yolo.utils.downloads import download, safe_download
+from ultralytics.yolo.utils.downloads import download, safe_download, unzip_file
 from ultralytics.yolo.utils.ops import segments2boxes

 HELP_URL = 'See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
@@ -290,3 +293,128 @@ def check_cls_dataset(dataset: str):
     names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()]  # class names list
     names = dict(enumerate(sorted(names)))
     return {'train': train_set, 'val': test_set, 'nc': nc, 'names': names}
+
+
+class HUBDatasetStats():
+    """ Class for generating HUB dataset JSON and `-hub` dataset directory
+
+    Arguments
+        path: Path to data.yaml or data.zip (with data.yaml inside data.zip)
+        autodownload: Attempt to download dataset if not found locally
+
+    Usage
+        from ultralytics.yolo.data.utils import HUBDatasetStats
+        stats = HUBDatasetStats('coco128.yaml', autodownload=True)  # usage 1
+        stats = HUBDatasetStats('/Users/glennjocher/Downloads/coco6.zip')  # usage 2
+        stats.get_json(save=False)
+        stats.process_images()
+    """
+
+    def __init__(self, path='coco128.yaml', autodownload=False):
+        # Initialize class
+        zipped, data_dir, yaml_path = self._unzip(Path(path))
+        try:
+            # data = yaml_load(check_yaml(yaml_path))  # data dict
+            data = check_det_dataset(yaml_path, autodownload)  # data dict
+            if zipped:
+                data['path'] = data_dir
+        except Exception as e:
+            raise Exception('error/HUB/dataset_stats/yaml_load') from e
+
+        self.hub_dir = Path(str(data['path']) + '-hub')
+        self.im_dir = self.hub_dir / 'images'
+        self.im_dir.mkdir(parents=True, exist_ok=True)  # makes /images
+        self.stats = {'nc': len(data['names']), 'names': list(data['names'].values())}  # statistics dictionary
+        self.data = data
+
+    @staticmethod
+    def _find_yaml(dir):
+        # Return data.yaml file
+        files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml'))  # try root level first and then recursive
+        assert files, f'No *.yaml file found in {dir}'
+        if len(files) > 1:
+            files = [f for f in files if f.stem == dir.stem]  # prefer *.yaml files that match dir name
+            assert files, f'Multiple *.yaml files found in {dir}, only 1 *.yaml file allowed'
+        assert len(files) == 1, f'Multiple *.yaml files found: {files}, only 1 *.yaml file allowed in {dir}'
+        return files[0]
+
+    def _unzip(self, path):
+        # Unzip data.zip
+        if not str(path).endswith('.zip'):  # path is data.yaml
+            return False, None, path
+        assert Path(path).is_file(), f'Error unzipping {path}, file not found'
+        unzip_file(path, path=path.parent)
+        dir = path.with_suffix('')  # dataset directory == zip name
+        assert dir.is_dir(), f'Error unzipping {path}, {dir} not found. path/to/abc.zip MUST unzip to path/to/abc/'
+        return True, str(dir), self._find_yaml(dir)  # zipped, data_dir, yaml_path
+
+    def _hub_ops(self, f, max_dim=1920):
+        # HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing
+        f_new = self.im_dir / Path(f).name  # dataset-hub image filename
+        try:  # use PIL
+            im = Image.open(f)
+            r = max_dim / max(im.height, im.width)  # ratio
+            if r < 1.0:  # image too large
+                im = im.resize((int(im.width * r), int(im.height * r)))
+            im.save(f_new, 'JPEG', quality=50, optimize=True)  # save
+        except Exception as e:  # use OpenCV
+            LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}')
+            im = cv2.imread(f)
+            im_height, im_width = im.shape[:2]
+            r = max_dim / max(im_height, im_width)  # ratio
+            if r < 1.0:  # image too large
+                im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
+            cv2.imwrite(str(f_new), im)
+
+    def get_json(self, save=False, verbose=False):
+        # Return dataset JSON for Ultralytics HUB
+        # from ultralytics.yolo.data import YOLODataset
+        from ultralytics.yolo.data.dataloaders.v5loader import LoadImagesAndLabels
+
+        def _round(labels):
+            # Update labels to integer class and 6 decimal place floats
+            return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels]
+
+        for split in 'train', 'val', 'test':
+            if self.data.get(split) is None:
+                self.stats[split] = None  # i.e. no test set
+                continue
+            dataset = LoadImagesAndLabels(self.data[split])  # load dataset
+            x = np.array([
+                np.bincount(label[:, 0].astype(int), minlength=self.data['nc'])
+                for label in tqdm(dataset.labels, total=len(dataset), desc='Statistics')])  # shape(128x80)
+            self.stats[split] = {
+                'instance_stats': {
+                    'total': int(x.sum()),
+                    'per_class': x.sum(0).tolist()},
+                'image_stats': {
+                    'total': len(dataset),
+                    'unlabelled': int(np.all(x == 0, 1).sum()),
+                    'per_class': (x > 0).sum(0).tolist()},
+                'labels': [{
+                    str(Path(k).name): _round(v.tolist())} for k, v in zip(dataset.im_files, dataset.labels)]}
+
+        # Save, print and return
+        if save:
+            stats_path = self.hub_dir / 'stats.json'
+            LOGGER.info(f'Saving {stats_path.resolve()}...')
+            with open(stats_path, 'w') as f:
+                json.dump(self.stats, f)  # save stats.json
+        if verbose:
+            LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
+        return self.stats
+
+    def process_images(self):
+        # Compress images for Ultralytics HUB
+        # from ultralytics.yolo.data import YOLODataset
+        from ultralytics.yolo.data.dataloaders.v5loader import LoadImagesAndLabels
+
+        for split in 'train', 'val', 'test':
+            if self.data.get(split) is None:
+                continue
+            dataset = LoadImagesAndLabels(self.data[split])  # load dataset
+            with ThreadPool(NUM_THREADS) as pool:
+                for _ in tqdm(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f'{split} images'):
+                    pass
+        LOGGER.info(f'Done. All images saved to {self.im_dir}')
+        return self.im_dir
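The class moves here essentially unchanged from the v5 dataloader module, so the usage from its docstring now imports from ultralytics.yolo.data.utils:

    from ultralytics.yolo.data.utils import HUBDatasetStats

    stats = HUBDatasetStats('coco128.yaml', autodownload=True)  # from a dataset YAML
    # stats = HUBDatasetStats('path/to/dataset.zip')            # or from a zip containing data.yaml
    stats.get_json(save=False)  # dataset statistics JSON for Ultralytics HUB
    stats.process_images()      # compressed copies written to the '-hub' images directory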
@@ -208,12 +208,15 @@ class Exporter:
         self.file = file
         self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else tuple(tuple(x.shape) for x in y)
         self.pretty_name = self.file.stem.replace('yolo', 'YOLO')
+        description = f'Ultralytics {self.pretty_name} model' + f'trained on {Path(self.args.data).name}' \
+            if self.args.data else '(untrained)'
         self.metadata = {
-            'description': f'Ultralytics {self.pretty_name} model trained on {Path(self.args.data).name}',
+            'description': description,
             'author': 'Ultralytics',
             'license': 'GPL-3.0 https://ultralytics.com/license',
             'version': __version__,
             'stride': int(max(model.stride)),
+            'task': model.task,
             'names': model.names}  # model metadata

         LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with input shape {tuple(im.shape)} BCHW and "
@@ -9,22 +9,22 @@ from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, Segmentat
                                   guess_model_task, nn)
 from ultralytics.yolo.cfg import get_cfg
 from ultralytics.yolo.engine.exporter import Exporter
-from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, yaml_load
+from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, callbacks, yaml_load
 from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_yaml
 from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS
 from ultralytics.yolo.utils.torch_utils import smart_inference_mode

 # Map head to model, trainer, validator, and predictor classes
-MODEL_MAP = {
+TASK_MAP = {
     'classify': [
-        ClassificationModel, 'yolo.TYPE.classify.ClassificationTrainer', 'yolo.TYPE.classify.ClassificationValidator',
-        'yolo.TYPE.classify.ClassificationPredictor'],
+        ClassificationModel, yolo.v8.classify.ClassificationTrainer, yolo.v8.classify.ClassificationValidator,
+        yolo.v8.classify.ClassificationPredictor],
     'detect': [
-        DetectionModel, 'yolo.TYPE.detect.DetectionTrainer', 'yolo.TYPE.detect.DetectionValidator',
-        'yolo.TYPE.detect.DetectionPredictor'],
+        DetectionModel, yolo.v8.detect.DetectionTrainer, yolo.v8.detect.DetectionValidator,
+        yolo.v8.detect.DetectionPredictor],
     'segment': [
-        SegmentationModel, 'yolo.TYPE.segment.SegmentationTrainer', 'yolo.TYPE.segment.SegmentationValidator',
-        'yolo.TYPE.segment.SegmentationPredictor']}
+        SegmentationModel, yolo.v8.segment.SegmentationTrainer, yolo.v8.segment.SegmentationValidator,
+        yolo.v8.segment.SegmentationPredictor]}


 class YOLO:
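TASK_MAP replaces the old string templates that were resolved with eval(); each task now maps directly to its model, trainer, validator and predictor classes at indices 0 to 3, which is how the rest of this commit looks them up. A quick check, assuming the package at this version is installed:

    from ultralytics.yolo.engine.model import TASK_MAP

    ModelClass, TrainerClass, ValidatorClass, PredictorClass = TASK_MAP['detect']
    print(ModelClass.__name__, TrainerClass.__name__)  # DetectionModel DetectionTrainer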
@@ -33,52 +33,48 @@ class YOLO:

     Args:
         model (str, Path): Path to the model file to load or create.
-        type (str): Type/version of models to use. Defaults to "v8".

     Attributes:
-        type (str): Type/version of models being used.
-        ModelClass (Any): Model class.
-        TrainerClass (Any): Trainer class.
-        ValidatorClass (Any): Validator class.
-        PredictorClass (Any): Predictor class.
-        predictor (Any): Predictor object.
-        model (Any): Model object.
-        trainer (Any): Trainer object.
-        task (str): Type of model task.
-        ckpt (Any): Checkpoint object if model loaded from *.pt file.
-        cfg (str): Model configuration if loaded from *.yaml file.
-        ckpt_path (str): Checkpoint file path.
-        overrides (dict): Overrides for trainer object.
-        metrics_data (Any): Data for metrics.
+        predictor (Any): The predictor object.
+        model (Any): The model object.
+        trainer (Any): The trainer object.
+        task (str): The type of model task.
+        ckpt (Any): The checkpoint object if the model loaded from *.pt file.
+        cfg (str): The model configuration if loaded from *.yaml file.
+        ckpt_path (str): The checkpoint file path.
+        overrides (dict): Overrides for the trainer object.
+        metrics_data (Any): The data for metrics.

     Methods:
-        __call__(): Alias for predict method.
-        _new(cfg, verbose=True): Initializes a new model and infers the task type from the model definitions.
-        _load(weights): Initializes a new model and infers the task type from the model head.
-        _check_is_pytorch_model(): Raises TypeError if model is not a PyTorch model.
-        reset(): Resets the model modules.
-        info(verbose=False): Logs model info.
-        fuse(): Fuse model for faster inference.
-        predict(source=None, stream=False, **kwargs): Perform prediction using the YOLO model.
+        __call__(source=None, stream=False, **kwargs):
+            Alias for the predict method.
+        _new(cfg:str, verbose:bool=True) -> None:
+            Initializes a new model and infers the task type from the model definitions.
+        _load(weights:str, task:str='') -> None:
+            Initializes a new model and infers the task type from the model head.
+        _check_is_pytorch_model() -> None:
+            Raises TypeError if the model is not a PyTorch model.
+        reset() -> None:
+            Resets the model modules.
+        info(verbose:bool=False) -> None:
+            Logs the model info.
+        fuse() -> None:
+            Fuses the model for faster inference.
+        predict(source=None, stream=False, **kwargs) -> List[ultralytics.yolo.engine.results.Results]:
+            Performs prediction using the YOLO model.

     Returns:
-        list(ultralytics.yolo.engine.results.Results): The prediction results.
+        list[ultralytics.yolo.engine.results.Results]: The prediction results.
     """

-    def __init__(self, model='yolov8n.pt', type='v8') -> None:
+    def __init__(self, model='yolov8n.pt') -> None:
         """
         Initializes the YOLO model.

         Args:
             model (str, Path): model to load or create
-            type (str): Type/version of models to use. Defaults to "v8".
         """
         self._reset_callbacks()
-        self.type = type
-        self.ModelClass = None  # model class
-        self.TrainerClass = None  # trainer class
-        self.ValidatorClass = None  # validator class
-        self.PredictorClass = None  # predictor class
         self.predictor = None  # reuse predictor
         self.model = None  # model object
         self.trainer = None  # trainer object
@@ -101,6 +97,10 @@ class YOLO:
     def __call__(self, source=None, stream=False, **kwargs):
         return self.predict(source, stream, **kwargs)

+    def __getattr__(self, attr):
+        name = self.__class__.__name__
+        raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
+
     def _new(self, cfg: str, verbose=True):
         """
         Initializes a new model and infers the task type from the model definitions.
@@ -112,11 +112,15 @@ class YOLO:
         self.cfg = check_yaml(cfg)  # check YAML
         cfg_dict = yaml_load(self.cfg, append_filename=True)  # model dict
         self.task = guess_model_task(cfg_dict)
-        self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task()
-        self.model = self.ModelClass(cfg_dict, verbose=verbose and RANK == -1)  # initialize
+        self.model = TASK_MAP[self.task][0](cfg_dict, verbose=verbose and RANK == -1)  # build model
         self.overrides['model'] = self.cfg

-    def _load(self, weights: str):
+        # Below added to allow export from yamls
+        args = {**DEFAULT_CFG_DICT, **self.overrides}  # combine model and default args, preferring model args
+        self.model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # attach args to model
+        self.model.task = self.task
+
+    def _load(self, weights: str, task=''):
         """
         Initializes a new model and infers the task type from the model head.

@@ -127,8 +131,7 @@ class YOLO:
         if suffix == '.pt':
             self.model, self.ckpt = attempt_load_one_weight(weights)
             self.task = self.model.args['task']
-            self.overrides = self.model.args
-            self._reset_ckpt_args(self.overrides)
+            self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
             self.ckpt_path = self.model.pt_path
         else:
             weights = check_file(weights)
@@ -136,7 +139,6 @@ class YOLO:
             self.task = guess_model_task(weights)
             self.ckpt_path = weights
         self.overrides['model'] = weights
-        self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task()

     def _check_is_pytorch_model(self):
         """
@@ -189,12 +191,13 @@ class YOLO:
         """
         overrides = self.overrides.copy()
         overrides['conf'] = 0.25
-        overrides.update(kwargs)
+        overrides.update(kwargs)  # prefer kwargs
         overrides['mode'] = kwargs.get('mode', 'predict')
         assert overrides['mode'] in ['track', 'predict']
         overrides['save'] = kwargs.get('save', False)  # not save files by default
         if not self.predictor:
-            self.predictor = self.PredictorClass(overrides=overrides)
+            self.task = overrides.get('task') or self.task
+            self.predictor = TASK_MAP[self.task][3](overrides=overrides)
             self.predictor.setup_model(model=self.model)
         else:  # only update args if predictor is already setup
             self.predictor.args = get_cfg(self.predictor.args, overrides)
@@ -226,12 +229,15 @@ class YOLO:
         overrides['mode'] = 'val'
         args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
         args.data = data or args.data
+        if 'task' in overrides:
+            self.task = args.task
+        else:
             args.task = self.task
         if args.imgsz == DEFAULT_CFG.imgsz and not isinstance(self.model, (str, Path)):
             args.imgsz = self.model.args['imgsz']  # use trained imgsz unless custom value is passed
         args.imgsz = check_imgsz(args.imgsz, max_dim=1)

-        validator = self.ValidatorClass(args=args)
+        validator = TASK_MAP[self.task][2](args=args)
         validator(model=self.model)
         self.metrics_data = validator.metrics

@@ -267,8 +273,7 @@ class YOLO:
             args.imgsz = self.model.args['imgsz']  # use trained imgsz unless custom value is passed
         if args.batch == DEFAULT_CFG.batch:
             args.batch = 1  # default to 1 if not modified
-        exporter = Exporter(overrides=args)
-        return exporter(model=self.model)
+        return Exporter(overrides=args)(model=self.model)

     def train(self, **kwargs):
         """
@@ -282,15 +287,15 @@ class YOLO:
         overrides.update(kwargs)
         if kwargs.get('cfg'):
             LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
-            overrides = yaml_load(check_yaml(kwargs['cfg']), append_filename=True)
-            overrides['task'] = self.task
+            overrides = yaml_load(check_yaml(kwargs['cfg']))
         overrides['mode'] = 'train'
         if not overrides.get('data'):
             raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
         if overrides.get('resume'):
             overrides['resume'] = self.ckpt_path

-        self.trainer = self.TrainerClass(overrides=overrides)
+        self.task = overrides.get('task') or self.task
+        self.trainer = TASK_MAP[self.task][1](overrides=overrides)
         if not overrides.get('resume'):  # manually set model only if not resuming
             self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
             self.model = self.trainer.model
@ -311,13 +316,6 @@ class YOLO:
|
|||||||
self._check_is_pytorch_model()
|
self._check_is_pytorch_model()
|
||||||
self.model.to(device)
|
self.model.to(device)
|
||||||
|
|
||||||
def _assign_ops_from_task(self):
|
|
||||||
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[self.task]
|
|
||||||
trainer_class = eval(train_lit.replace('TYPE', f'{self.type}'))
|
|
||||||
validator_class = eval(val_lit.replace('TYPE', f'{self.type}'))
|
|
||||||
predictor_class = eval(pred_lit.replace('TYPE', f'{self.type}'))
|
|
||||||
return model_class, trainer_class, validator_class, predictor_class
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def names(self):
|
def names(self):
|
||||||
"""
|
"""
|
||||||
@ -357,9 +355,8 @@ class YOLO:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reset_ckpt_args(args):
|
def _reset_ckpt_args(args):
|
||||||
for arg in 'augment', 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', \
|
include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
|
||||||
'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots', 'opset', 'simplify':
|
return {k: v for k, v in args.items() if k in include}
|
||||||
args.pop(arg, None)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reset_callbacks():
|
def _reset_callbacks():
|
||||||
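Taken together, the edits above replace the old per-instance PredictorClass/TrainerClass/ValidatorClass attributes with a single TASK_MAP lookup keyed by the model's task. A minimal, hedged usage sketch of the resulting high-level API (weights, image and dataset names are illustrative):

from ultralytics import YOLO

model = YOLO('yolov8n.pt')            # task is read from the loaded checkpoint
model.predict('bus.jpg', conf=0.25)   # predictor resolved via TASK_MAP[task][3]
model.val(data='coco128.yaml')        # validator resolved via TASK_MAP[task][2]
model.export(format='onnx')           # batch defaults to 1 unless overridden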
@@ -108,7 +108,6 @@ class BasePredictor:
     def postprocess(self, preds, img, orig_img):
         return preds

-    @smart_inference_mode()
     def __call__(self, source=None, model=None, stream=False):
         if stream:
             return self.stream_inference(source, model)
@@ -136,6 +135,7 @@ class BasePredictor:
         self.source_type = self.dataset.source_type
         self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs

+    @smart_inference_mode()
     def stream_inference(self, source=None, model=None):
         if self.args.verbose:
             LOGGER.info('')
@@ -161,12 +161,14 @@ class BasePredictor:
             self.batch = batch
             path, im, im0s, vid_cap, s = batch
             visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False

+            # preprocess
             with self.dt[0]:
                 im = self.preprocess(im)
                 if len(im.shape) == 3:
                     im = im[None]  # expand for batch dim

-            # Inference
+            # inference
             with self.dt[1]:
                 preds = self.model(im, augment=self.args.augment, visualize=visualize)

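The new "# preprocess" and "# inference" comments label the two profiled stages whose timings later feed the per-image speed report. A rough sketch of that accumulate-elapsed-milliseconds pattern, using a generic timer rather than the library's own profiler (all names below are illustrative):

import time
from contextlib import contextmanager

@contextmanager
def profile(times, key):
    # accumulate elapsed milliseconds per pipeline stage
    start = time.perf_counter()
    yield
    times[key] = times.get(key, 0.0) + (time.perf_counter() - start) * 1e3

times = {}
with profile(times, 'preprocess'):
    batch = [x * 0.5 for x in range(1000)]   # stand-in for image preprocessing
with profile(times, 'inference'):
    preds = sum(batch)                        # stand-in for the forward pass
print(times)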
@@ -21,26 +21,30 @@ class Results:
     A class for storing and manipulating inference results.

     Args:
-        boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
-        masks (Masks, optional): A Masks object containing the detection masks.
-        probs (torch.Tensor, optional): A tensor containing the detection class probabilities.
-        orig_img (tuple, optional): Original image size.
+        orig_img (numpy.ndarray): The original image as a numpy array.
+        path (str): The path to the image file.
+        names (List[str]): A list of class names.
+        boxes (List[List[float]], optional): A list of bounding box coordinates for each detection.
+        masks (numpy.ndarray, optional): A 3D numpy array of detection masks, where each mask is a binary image.
+        probs (numpy.ndarray, optional): A 2D numpy array of detection probabilities for each class.

     Attributes:
+        orig_img (numpy.ndarray): The original image as a numpy array.
+        orig_shape (tuple): The original image shape in (height, width) format.
         boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
         masks (Masks, optional): A Masks object containing the detection masks.
-        probs (torch.Tensor, optional): A tensor containing the detection class probabilities.
-        orig_img (tuple, optional): Original image size.
-        data (torch.Tensor): The raw masks tensor
+        probs (numpy.ndarray, optional): A 2D numpy array of detection probabilities for each class.
+        names (List[str]): A list of class names.
+        path (str): The path to the image file.
+        _keys (tuple): A tuple of attribute names for non-empty attributes.
     """

     def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None) -> None:
         self.orig_img = orig_img
         self.orig_shape = orig_img.shape[:2]
-        self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None  # native size boxes
-        self.masks = Masks(masks, self.orig_shape) if masks is not None else None  # native size or imgsz masks
-        self.probs = probs if probs is not None else None
+        self.boxes = Boxes(boxes.cpu(), self.orig_shape) if boxes is not None else None  # native size boxes
+        self.masks = Masks(masks.cpu(), self.orig_shape) if masks is not None else None  # native size or imgsz masks
+        self.probs = probs.cpu() if probs is not None else None
         self.names = names
         self.path = path
         self._keys = (k for k in ('boxes', 'masks', 'probs') if getattr(self, k) is not None)
@@ -99,24 +103,22 @@ class Results:

     def __getattr__(self, attr):
         name = self.__class__.__name__
-        raise AttributeError(f"""
-            '{name}' object has no attribute '{attr}'. Valid '{name}' object attributes and properties are:
-
-            Attributes:
-                boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
-                masks (Masks, optional): A Masks object containing the detection masks.
-                probs (torch.Tensor, optional): A tensor containing the detection class probabilities.
-                orig_shape (tuple, optional): Original image size.
-            """)
+        raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")

     def plot(self, show_conf=True, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
         """
-        Plots the given result on an input RGB image. Accepts cv2(numpy) or PIL Image
+        Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image.

         Args:
-            show_conf (bool): Show confidence
-            line_width (Float): The line width of boxes. Automatically scaled to img size if not provided
-            font_size (Float): The font size of . Automatically scaled to img size if not provided
+            show_conf (bool): Whether to show the detection confidence score.
+            line_width (float, optional): The line width of the bounding boxes. If None, it is automatically scaled to the image size.
+            font_size (float, optional): The font size of the text. If None, it is automatically scaled to the image size.
+            font (str): The font to use for the text.
+            pil (bool): Whether to return the image as a PIL Image.
+            example (str): An example string to display in the plot. Useful for indicating the expected format of the output.
+
+        Returns:
+            None or PIL Image: If `pil` is True, the image will be returned as a PIL Image. Otherwise, nothing is returned.
         """
         img = deepcopy(self.orig_img)
         annotator = Annotator(img, line_width, font_size, font, pil, example)
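The reworked Results docstring and constructor above describe the object returned per image at inference time, with boxes, masks and probs now moved to CPU up front. A short usage sketch under those assumptions (weights and image path are illustrative):

from ultralytics import YOLO

model = YOLO('yolov8n-seg.pt')
for res in model.predict('bus.jpg'):
    if res.boxes is not None:
        print(res.boxes.xyxy, res.boxes.conf, res.boxes.cls)   # per-detection tensors
    if res.masks is not None:
        print(len(res.masks.segments))                         # one polygon per detected instance
    res.plot(show_conf=True)                                   # draw detections on a copy of the original image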
@@ -157,15 +159,24 @@ class Boxes:
         boxes (torch.Tensor) or (numpy.ndarray): A tensor or numpy array containing the detection boxes,
             with shape (num_boxes, 6).
         orig_shape (torch.Tensor) or (numpy.ndarray): Original image size, in the format (height, width).
+        is_track (bool): True if the boxes also include track IDs, False otherwise.

     Properties:
         xyxy (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format.
         conf (torch.Tensor) or (numpy.ndarray): The confidence values of the boxes.
         cls (torch.Tensor) or (numpy.ndarray): The class values of the boxes.
+        id (torch.Tensor) or (numpy.ndarray): The track IDs of the boxes (if available).
         xywh (torch.Tensor) or (numpy.ndarray): The boxes in xywh format.
         xyxyn (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format normalized by original image size.
         xywhn (torch.Tensor) or (numpy.ndarray): The boxes in xywh format normalized by original image size.
         data (torch.Tensor): The raw bboxes tensor

+    Methods:
+        cpu(): Move the object to CPU memory.
+        numpy(): Convert the object to a numpy array.
+        cuda(): Move the object to CUDA memory.
+        to(*args, **kwargs): Move the object to the specified device.
+        pandas(): Convert the object to a pandas DataFrame (not yet implemented).
     """

     def __init__(self, boxes, orig_shape) -> None:
@@ -257,22 +268,7 @@ class Boxes:

     def __getattr__(self, attr):
         name = self.__class__.__name__
-        raise AttributeError(f"""
-            '{name}' object has no attribute '{attr}'. Valid '{name}' object attributes and properties are:
-
-            Attributes:
-                boxes (torch.Tensor) or (numpy.ndarray): A tensor or numpy array containing the detection boxes,
-                    with shape (num_boxes, 6).
-                orig_shape (torch.Tensor) or (numpy.ndarray): Original image size, in the format (height, width).
-
-            Properties:
-                xyxy (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format.
-                conf (torch.Tensor) or (numpy.ndarray): The confidence values of the boxes.
-                cls (torch.Tensor) or (numpy.ndarray): The class values of the boxes.
-                xywh (torch.Tensor) or (numpy.ndarray): The boxes in xywh format.
-                xyxyn (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format normalized by original image size.
-                xywhn (torch.Tensor) or (numpy.ndarray): The boxes in xywh format normalized by original image size.
-            """)
+        raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")


 class Masks:
@@ -289,6 +285,17 @@ class Masks:

     Properties:
         segments (list): A list of segments which includes x, y, w, h, label, confidence, and mask of each detection masks.

+    Methods:
+        cpu(): Returns a copy of the masks tensor on CPU memory.
+        numpy(): Returns a copy of the masks tensor as a numpy array.
+        cuda(): Returns a copy of the masks tensor on GPU memory.
+        to(): Returns a copy of the masks tensor with the specified device and dtype.
+        __len__(): Returns the number of masks in the tensor.
+        __str__(): Returns a string representation of the masks tensor.
+        __repr__(): Returns a detailed string representation of the masks tensor.
+        __getitem__(): Returns a new Masks object with the masks at the specified index.
+        __getattr__(): Raises an AttributeError with a list of valid attributes and properties.
     """

     def __init__(self, masks, orig_shape) -> None:
@@ -337,13 +344,4 @@ class Masks:

     def __getattr__(self, attr):
         name = self.__class__.__name__
-        raise AttributeError(f"""
-            '{name}' object has no attribute '{attr}'. Valid '{name}' object attributes and properties are:
-
-            Attributes:
-                masks (torch.Tensor): A tensor containing the detection masks, with shape (num_masks, height, width).
-                orig_shape (tuple): Original image size, in the format (height, width).
-
-            Properties:
-                segments (list): A list of segments which includes x,y,w,h,label,confidence, and mask of each detection masks.
-            """)
+        raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
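The Methods lists added above document that Boxes and Masks behave like thin tensor containers: they can be moved between devices or converted to numpy as a whole. Continuing the sketch from the Results example, using only the attribute names documented above:

res = model.predict('bus.jpg')[0]      # first Results object from the earlier sketch
if res.boxes is not None:
    xyxy = res.boxes.xyxy              # (N, 4) corner coordinates
    xywhn = res.boxes.xywhn            # normalized centre-x, centre-y, width, height
    boxes_np = res.boxes.numpy()       # whole container converted to numpy
    boxes_gpu = res.boxes.cuda()       # or res.boxes.to('cuda:0'), if a GPU is available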
@@ -44,7 +44,6 @@ class BaseTrainer:
     Attributes:
         args (SimpleNamespace): Configuration for the trainer.
         check_resume (method): Method to check if training should be resumed from a saved checkpoint.
-        console (logging.Logger): Logger instance.
         validator (BaseValidator): Validator instance.
         model (nn.Module): Model instance.
         callbacks (defaultdict): Dictionary of callbacks.
@@ -84,7 +83,6 @@ class BaseTrainer:
         self.args = get_cfg(cfg, overrides)
         self.device = select_device(self.args.device, self.args.batch)
         self.check_resume()
-        self.console = LOGGER
         self.validator = None
         self.model = None
         self.metrics = None
@@ -180,11 +178,12 @@ class BaseTrainer:
         if world_size > 1 and 'LOCAL_RANK' not in os.environ:
             cmd, file = generate_ddp_command(world_size, self)  # security vulnerability in Snyk scans
             try:
+                LOGGER.info(f'Running DDP command {cmd}')
                 subprocess.run(cmd, check=True)
             except Exception as e:
-                self.console.warning(e)
+                LOGGER.warning(e)
             finally:
-                ddp_cleanup(self, file)
+                ddp_cleanup(self, str(file))
         else:
             self._do_train(RANK, world_size)

@@ -193,7 +192,7 @@ class BaseTrainer:
         # os.environ['MASTER_PORT'] = '9020'
         torch.cuda.set_device(rank)
         self.device = torch.device('cuda', rank)
-        self.console.info(f'DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}')
+        LOGGER.info(f'DDP settings: RANK {rank}, WORLD_SIZE {world_size}, DEVICE {self.device}')
         dist.init_process_group('nccl' if dist.is_nccl_available() else 'gloo', rank=rank, world_size=world_size)

     def _setup_train(self, rank, world_size):
@@ -262,7 +261,7 @@ class BaseTrainer:
         nw = max(round(self.args.warmup_epochs * nb), 100)  # number of warmup iterations
         last_opt_step = -1
         self.run_callbacks('on_train_start')
-        self.log(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
+        LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
                     f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
                     f"Logging results to {colorstr('bold', self.save_dir)}\n"
                     f'Starting training for {self.epochs} epochs...')
@@ -278,14 +277,14 @@ class BaseTrainer:
             pbar = enumerate(self.train_loader)
             # Update dataloader attributes (optional)
             if epoch == (self.epochs - self.args.close_mosaic):
-                self.console.info('Closing dataloader mosaic')
+                LOGGER.info('Closing dataloader mosaic')
                 if hasattr(self.train_loader.dataset, 'mosaic'):
                     self.train_loader.dataset.mosaic = False
                 if hasattr(self.train_loader.dataset, 'close_mosaic'):
                     self.train_loader.dataset.close_mosaic(hyp=self.args)

             if rank in {-1, 0}:
-                self.console.info(self.progress_string())
+                LOGGER.info(self.progress_string())
                 pbar = tqdm(enumerate(self.train_loader), total=nb, bar_format=TQDM_BAR_FORMAT)
             self.tloss = None
             self.optimizer.zero_grad()
@@ -372,12 +371,11 @@ class BaseTrainer:

         if rank in {-1, 0}:
             # Do final val with best.pt
-            self.log(f'\n{epoch - self.start_epoch + 1} epochs completed in '
+            LOGGER.info(f'\n{epoch - self.start_epoch + 1} epochs completed in '
                      f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
             self.final_eval()
             if self.args.plots:
                 self.plot_metrics()
-            self.log(f"Results saved to {colorstr('bold', self.save_dir)}")
             self.run_callbacks('on_train_end')
         torch.cuda.empty_cache()
         self.run_callbacks('teardown')
@@ -450,18 +448,6 @@ class BaseTrainer:
             self.best_fitness = fitness
         return metrics, fitness

-    def log(self, text, rank=-1):
-        """
-        Logs the given text to given ranks process if provided, otherwise logs to all ranks.
-
-        Args"
-            text (str): text to log
-            rank (List[Int]): process rank
-
-        """
-        if rank in {-1, 0}:
-            self.console.info(text)
-
     def get_model(self, cfg=None, weights=None, verbose=True):
         raise NotImplementedError("This task trainer doesn't support loading cfg files")

@@ -521,7 +507,7 @@ class BaseTrainer:
             if f.exists():
                 strip_optimizer(f)  # strip optimizers
                 if f is self.best:
-                    self.console.info(f'\nValidating {f}...')
+                    LOGGER.info(f'\nValidating {f}...')
                     self.metrics = self.validator(model=f)
                     self.metrics.pop('fitness', None)
                     self.run_callbacks('on_fit_epoch_end')
@@ -564,7 +550,7 @@ class BaseTrainer:
         self.best_fitness = best_fitness
         self.start_epoch = start_epoch
         if start_epoch > (self.epochs - self.args.close_mosaic):
-            self.console.info('Closing dataloader mosaic')
+            LOGGER.info('Closing dataloader mosaic')
             if hasattr(self.train_loader.dataset, 'mosaic'):
                 self.train_loader.dataset.mosaic = False
             if hasattr(self.train_loader.dataset, 'close_mosaic'):
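The deleted BaseTrainer.log() helper above only filtered messages by process rank before delegating to the module logger; the commit replaces every call site with LOGGER directly. For reference, the equivalent rank-0 guard looks roughly like this (the logger name and environment-variable handling are illustrative, not the library's exact mechanism):

import logging
import os

LOGGER = logging.getLogger('yolo_example')   # stand-in for ultralytics' module-level logger

def log_rank0(text):
    # only the main process (RANK -1 or 0) should emit progress messages under DDP
    if int(os.getenv('RANK', -1)) in {-1, 0}:
        LOGGER.info(text)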
@@ -44,7 +44,6 @@ class BaseValidator:
     Attributes:
         dataloader (DataLoader): Dataloader to use for validation.
         pbar (tqdm): Progress bar to update during validation.
-        logger (logging.Logger): Logger to use for validation.
         args (SimpleNamespace): Configuration for the validator.
         model (nn.Module): Model to validate.
         data (dict): Data dictionary.
@@ -56,7 +55,7 @@ class BaseValidator:
         save_dir (Path): Directory to save results.
     """

-    def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None):
         """
         Initializes a BaseValidator instance.

@@ -69,14 +68,13 @@ class BaseValidator:
         """
         self.dataloader = dataloader
         self.pbar = pbar
-        self.logger = logger or LOGGER
         self.args = args or get_cfg(DEFAULT_CFG)
         self.model = None
         self.data = None
         self.device = None
         self.batch_i = None
         self.training = True
-        self.speed = None
+        self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
         self.jdict = None

         project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
@@ -123,14 +121,14 @@ class BaseValidator:
             self.device = model.device
             if not pt and not jit:
                 self.args.batch = 1  # export.py models default to batch-size 1
-                self.logger.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
+                LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')

             if isinstance(self.args.data, str) and self.args.data.endswith('.yaml'):
                 self.data = check_det_dataset(self.args.data)
             elif self.args.task == 'classify':
                 self.data = check_cls_dataset(self.args.data)
             else:
-                raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' not found ❌"))
+                raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))

             if self.device.type == 'cpu':
                 self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
@@ -179,7 +177,7 @@ class BaseValidator:
         stats = self.get_stats()
         self.check_stats(stats)
         self.print_results()
-        self.speed = tuple(x.t / len(self.dataloader.dataset) * 1E3 for x in dt)  # speeds per image
+        self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1E3 for x in dt)))
         self.finalize_metrics()
         self.run_callbacks('on_val_end')
         if self.training:
@@ -187,11 +185,11 @@ class BaseValidator:
             results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')}
             return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
         else:
-            self.logger.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
-                             self.speed)
+            LOGGER.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
+                        tuple(self.speed.values()))
             if self.args.save_json and self.jdict:
                 with open(str(self.save_dir / 'predictions.json'), 'w') as f:
-                    self.logger.info(f'Saving {f.name}...')
+                    LOGGER.info(f'Saving {f.name}...')
                     json.dump(self.jdict, f)  # flatten and save
                 stats = self.eval_json(stats)  # update stats
             if self.args.plots or self.args.save_json:
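The validator's speed attribute above changes from a positional tuple into a dict keyed by pipeline stage, which is what lets downstream code read speed['inference'] instead of speed[1]. A small self-contained sketch of how the per-image figures are derived (the totals and image count are made up for illustration):

speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
stage_totals_s = [0.8, 5.1, 0.0, 1.2]   # total seconds spent per stage (illustrative)
num_images = 640

speed = dict(zip(speed.keys(), (t / num_images * 1E3 for t in stage_totals_s)))  # ms per image
print('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image'
      % tuple(speed.values()))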
@@ -60,7 +60,7 @@ def run_benchmarks(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',

         # Export
         if format == '-':
-            filename = model.ckpt_path
+            filename = model.ckpt_path or model.cfg
             export = model  # PyTorch format
         else:
             filename = model.export(imgsz=imgsz, format=format, half=half, device=device)  # all others
@@ -29,7 +29,7 @@ def on_pretrain_routine_start(trainer):
                          auto_connect_frameworks={'pytorch': False})
         task.connect(vars(trainer.args), name='General')
     except Exception as e:
-        LOGGER.warning(f'WARNING ⚠️ ClearML not initialized correctly, not logging this run. {e}')
+        LOGGER.warning(f'WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}')


 def on_train_epoch_end(trainer):
@@ -41,9 +41,9 @@ def on_fit_epoch_end(trainer):
     task = Task.current_task()
     if task and trainer.epoch == 0:
         model_info = {
-            'Parameters': get_num_params(trainer.model),
-            'GFLOPs': round(get_flops(trainer.model), 3),
-            'Inference speed (ms/img)': round(trainer.validator.speed[1], 3)}
+            'model/parameters': get_num_params(trainer.model),
+            'model/GFLOPs': round(get_flops(trainer.model), 3),
+            'model/speed(ms)': round(trainer.validator.speed['inference'], 3)}
         task.connect(model_info, name='Model')


@@ -16,7 +16,7 @@ def on_pretrain_routine_start(trainer):
         experiment = comet_ml.Experiment(project_name=trainer.args.project or 'YOLOv8')
         experiment.log_parameters(vars(trainer.args))
     except Exception as e:
-        LOGGER.warning(f'WARNING ⚠️ Comet not initialized correctly, not logging this run. {e}')
+        LOGGER.warning(f'WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}')


 def on_train_epoch_end(trainer):
@@ -36,7 +36,7 @@ def on_fit_epoch_end(trainer):
         model_info = {
             'model/parameters': get_num_params(trainer.model),
             'model/GFLOPs': round(get_flops(trainer.model), 3),
-            'model/speed(ms)': round(trainer.validator.speed[1], 3)}
+            'model/speed(ms)': round(trainer.validator.speed['inference'], 3)}
         experiment.log_metrics(model_info, step=trainer.epoch + 1)

@@ -2,17 +2,24 @@

 from torch.utils.tensorboard import SummaryWriter

+from ultralytics.yolo.utils import LOGGER
+
 writer = None  # TensorBoard SummaryWriter instance


 def _log_scalars(scalars, step=0):
+    if writer:
         for k, v in scalars.items():
             writer.add_scalar(k, v, step)


 def on_pretrain_routine_start(trainer):
     global writer
+    try:
         writer = SummaryWriter(str(trainer.save_dir))
+    except Exception as e:
+        writer = None  # TensorBoard SummaryWriter instance
+        LOGGER.warning(f'WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}')


 def on_batch_end(trainer):
@@ -254,7 +254,7 @@ def check_file(file, suffix='', download=True):
         return file
     else:  # search
         files = []
-        for d in 'models', 'datasets', 'tracker/cfg':  # search directories
+        for d in 'models', 'datasets', 'tracker/cfg', 'yolo/cfg':  # search directories
             files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True))  # find file
         if not files:
             raise FileNotFoundError(f"'{file}' does not exist")
@@ -51,10 +51,9 @@ def generate_ddp_command(world_size, trainer):
     file = generate_ddp_file(trainer)  # if argv[0].endswith('yolo') else os.path.abspath(argv[0])

     # Build command
-    torch_distributed_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
-    cmd = [
-        sys.executable, '-m', torch_distributed_cmd, '--nproc_per_node', f'{world_size}', '--master_port',
-        f'{find_free_network_port()}', file] + args
+    dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
+    port = find_free_network_port()
+    cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}', '--master_port', f'{port}', file] + args
     return cmd, file

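The rewrite above folds the DDP launch command into a single list and picks the free port before building it. Expanded, the spawned command looks roughly like this (port, file path and argument list are illustrative):

import sys

world_size = 2
port = 12345                          # the real code takes this from find_free_network_port()
file = 'ddp_train_entry.py'           # generated trainer entry point (placeholder name)
args = []

dist_cmd = 'torch.distributed.run'    # falls back to 'torch.distributed.launch' on torch < 1.9
cmd = [sys.executable, '-m', dist_cmd, '--nproc_per_node', f'{world_size}',
       '--master_port', f'{port}', file] + args
print(' '.join(cmd))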
@@ -12,7 +12,7 @@ import requests
 import torch
 from tqdm import tqdm

-from ultralytics.yolo.utils import LOGGER
+from ultralytics.yolo.utils import LOGGER, checks

 GITHUB_ASSET_NAMES = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] + \
                      [f'yolov5{size}u.pt' for size in 'nsmlx'] + \
@@ -87,7 +87,7 @@ def safe_download(url,
         try:
             if curl or i > 0:  # curl download with retry, continue
                 s = 'sS' * (not progress)  # silent
-                r = subprocess.run(['curl', '-#', f'-{s}L', url, '-o', f, '--retry', '9', '-C', '-']).returncode
+                r = subprocess.run(['curl', '-#', f'-{s}L', url, '-o', f, '--retry', '3', '-C', '-']).returncode
                 assert r == 0, f'Curl return value {r}'
             else:  # urllib download
                 method = 'torch'
@@ -112,8 +112,10 @@ def safe_download(url,
                     break  # success
                 f.unlink()  # remove partial downloads
         except Exception as e:
-            if i >= retry:
-                raise ConnectionError(f'❌  Download failure for {url}') from e
+            if i == 0 and not checks.check_online():
+                raise ConnectionError(f'❌  Download failure for {url}. Environment is not online.') from e
+            elif i >= retry:
+                raise ConnectionError(f'❌  Download failure for {url}. Retry limit reached.') from e
             LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')

     if unzip and f.exists() and f.suffix in {'.zip', '.tar', '.gz'}:
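The new except branch above separates two failure modes: if the very first attempt fails while the environment is offline it raises immediately, otherwise it keeps retrying up to the limit. A compressed sketch of that control flow, with a hypothetical fetch callable standing in for the curl/urllib download paths:

def download_with_retry(url, fetch, is_online, retry=3):
    # fetch(url) raises on failure; is_online() reports basic connectivity
    for i in range(retry + 1):
        try:
            return fetch(url)
        except Exception as e:
            if i == 0 and not is_online():
                raise ConnectionError(f'Download failure for {url}. Environment is not online.') from e
            elif i >= retry:
                raise ConnectionError(f'Download failure for {url}. Retry limit reached.') from e
            print(f'Download failure, retrying {i + 1}/{retry} {url}...')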
@@ -7,7 +7,7 @@ from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight
 from ultralytics.yolo import v8
 from ultralytics.yolo.data import build_classification_dataloader
 from ultralytics.yolo.engine.trainer import BaseTrainer
-from ultralytics.yolo.utils import DEFAULT_CFG, RANK
+from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
 from ultralytics.yolo.utils.torch_utils import is_parallel, strip_optimizer


@@ -64,6 +64,7 @@ class ClassificationTrainer(BaseTrainer):
             self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
         else:
             FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
+        ClassificationModel.reshape_outputs(self.model, self.data['nc'])

         return  # dont return ckpt. Classification doesn't support resume

@@ -93,7 +94,7 @@ class ClassificationTrainer(BaseTrainer):

     def get_validator(self):
         self.loss_names = ['loss']
-        return v8.classify.ClassificationValidator(self.test_loader, self.save_dir, logger=self.console)
+        return v8.classify.ClassificationValidator(self.test_loader, self.save_dir)

     def criterion(self, preds, batch):
         loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / self.args.nbs
@@ -132,11 +133,12 @@ class ClassificationTrainer(BaseTrainer):
                 strip_optimizer(f)  # strip optimizers
                 # TODO: validate best.pt after training completes
                 # if f is self.best:
-                #     self.console.info(f'\nValidating {f}...')
+                #     LOGGER.info(f'\nValidating {f}...')
                 #     self.validator.args.save_json = True
                 #     self.metrics = self.validator(model=f)
                 #     self.metrics.pop('fitness', None)
                 #     self.run_callbacks('on_fit_epoch_end')
+        LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")


 def train(cfg=DEFAULT_CFG, use_python=False):
@@ -2,14 +2,14 @@

 from ultralytics.yolo.data import build_classification_dataloader
 from ultralytics.yolo.engine.validator import BaseValidator
-from ultralytics.yolo.utils import DEFAULT_CFG
+from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER
 from ultralytics.yolo.utils.metrics import ClassifyMetrics


 class ClassificationValidator(BaseValidator):

-    def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
-        super().__init__(dataloader, save_dir, pbar, logger, args)
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None):
+        super().__init__(dataloader, save_dir, pbar, args)
         self.args.task = 'classify'
         self.metrics = ClassifyMetrics()

@@ -31,7 +31,7 @@ class ClassificationValidator(BaseValidator):
         self.targets.append(batch['cls'])

     def finalize_metrics(self, *args, **kwargs):
-        self.metrics.speed = dict(zip(self.metrics.speed.keys(), self.speed))
+        self.metrics.speed = self.speed

     def get_stats(self):
         self.metrics.process(self.targets, self.pred)
@@ -45,7 +45,7 @@ class ClassificationValidator(BaseValidator):

     def print_results(self):
         pf = '%22s' + '%11.3g' * len(self.metrics.keys)  # print format
-        self.logger.info(pf % ('all', self.metrics.top1, self.metrics.top5))
+        LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5))


 def val(cfg=DEFAULT_CFG, use_python=False):
@@ -66,10 +66,7 @@ class DetectionTrainer(BaseTrainer):

     def get_validator(self):
         self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
-        return v8.detect.DetectionValidator(self.test_loader,
-                                            save_dir=self.save_dir,
-                                            logger=self.console,
-                                            args=copy(self.args))
+        return v8.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

     def criterion(self, preds, batch):
         if not hasattr(self, 'compute_loss'):
@@ -9,7 +9,7 @@ import torch
 from ultralytics.yolo.data import build_dataloader
 from ultralytics.yolo.data.dataloaders.v5loader import create_dataloader
 from ultralytics.yolo.engine.validator import BaseValidator
-from ultralytics.yolo.utils import DEFAULT_CFG, colorstr, ops
+from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, colorstr, ops
 from ultralytics.yolo.utils.checks import check_requirements
 from ultralytics.yolo.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
 from ultralytics.yolo.utils.plotting import output_to_target, plot_images
@@ -18,8 +18,8 @@ from ultralytics.yolo.utils.torch_utils import de_parallel

 class DetectionValidator(BaseValidator):

-    def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
-        super().__init__(dataloader, save_dir, pbar, logger, args)
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None):
+        super().__init__(dataloader, save_dir, pbar, args)
         self.args.task = 'detect'
         self.is_coco = False
         self.class_map = None
@@ -112,7 +112,7 @@ class DetectionValidator(BaseValidator):
             # save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')

     def finalize_metrics(self, *args, **kwargs):
-        self.metrics.speed = dict(zip(self.metrics.speed.keys(), self.speed))
+        self.metrics.speed = self.speed

     def get_stats(self):
         stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)]  # to numpy
@@ -123,15 +123,15 @@ class DetectionValidator(BaseValidator):

     def print_results(self):
         pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys)  # print format
-        self.logger.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
+        LOGGER.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
         if self.nt_per_class.sum() == 0:
-            self.logger.warning(
+            LOGGER.warning(
                 f'WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels')

         # Print results per class
         if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
             for i, c in enumerate(self.metrics.ap_class_index):
-                self.logger.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
+                LOGGER.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))

         if self.args.plots:
             self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values()))
@@ -212,7 +212,7 @@ class DetectionValidator(BaseValidator):
         if self.args.save_json and self.is_coco and len(self.jdict):
             anno_json = self.data['path'] / 'annotations/instances_val2017.json'  # annotations
             pred_json = self.save_dir / 'predictions.json'  # predictions
-            self.logger.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
+            LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
             try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
                 check_requirements('pycocotools>=2.0.6')
                 from pycocotools.coco import COCO  # noqa
@@ -230,7 +230,7 @@ class DetectionValidator(BaseValidator):
                 eval.summarize()
                 stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = eval.stats[:2]  # update mAP50-95 and mAP50
             except Exception as e:
-                self.logger.warning(f'pycocotools unable to run: {e}')
+                LOGGER.warning(f'pycocotools unable to run: {e}')
         return stats

@@ -68,11 +68,10 @@ class SegmentationPredictor(DetectionPredictor):
             log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "

         # Mask plotting
-        self.annotator.masks(
-            mask.masks,
-            colors=[colors(x, True) for x in det.cls],
-            im_gpu=torch.as_tensor(im0, dtype=torch.float16).to(self.device).permute(2, 0, 1).flip(0).contiguous() /
-            255 if self.args.retina_masks else im[idx])
+        if self.args.save or self.args.show:
+            im_gpu = torch.as_tensor(im0, dtype=torch.float16, device=mask.masks.device).permute(
+                2, 0, 1).flip(0).contiguous() / 255 if self.args.retina_masks else im[idx]
+            self.annotator.masks(masks=mask.masks, colors=[colors(x, True) for x in det.cls], im_gpu=im_gpu)

         # Write results
         for j, d in enumerate(reversed(det)):
@@ -32,10 +32,7 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):

     def get_validator(self):
         self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss'
-        return v8.segment.SegmentationValidator(self.test_loader,
-                                                save_dir=self.save_dir,
-                                                logger=self.console,
-                                                args=copy(self.args))
+        return v8.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

     def criterion(self, preds, batch):
         if not hasattr(self, 'compute_loss'):
@@ -86,10 +83,6 @@ class SegLoss(Loss):
         gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
         mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)

-        masks = batch['masks'].to(self.device).float()
-        if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
-            masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
-
         # pboxes
         pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)

@@ -103,10 +96,15 @@ class SegLoss(Loss):
         # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
         loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

-        # bbox loss
         if fg_mask.sum():
+            # bbox loss
             loss[0], loss[3] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes / stride_tensor,
                                               target_scores, target_scores_sum, fg_mask)
+            # masks loss
+            masks = batch['masks'].to(self.device).float()
+            if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
+                masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]

             for i in range(batch_size):
                 if fg_mask[i].sum():
                     mask_idx = target_gt_idx[i][fg_mask[i]]
@@ -121,9 +119,9 @@ class SegLoss(Loss):
                                          marea)  # seg loss
             # WARNING: Uncomment lines below in case of Multi-GPU DDP unused gradient errors
             # else:
-            #     loss[1] += proto.sum() * 0
+            #     loss[1] += proto.sum() * 0 + pred_masks.sum() * 0
         # else:
-        #     loss[1] += proto.sum() * 0
+        #     loss[1] += proto.sum() * 0 + pred_masks.sum() * 0

         loss[0] *= self.hyp.box  # box gain
         loss[1] *= self.hyp.box / batch_size  # seg gain
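Moving the mask preparation above inside the if fg_mask.sum(): branch means the ground-truth masks are only fetched and downsampled when some anchors were actually assigned. The downsampling itself is plain nearest-neighbour interpolation to the prototype resolution, for example:

import torch
import torch.nn.functional as F

masks = torch.zeros(4, 640, 640)     # illustrative ground-truth masks, shape (N, H, W)
mask_h, mask_w = 160, 160            # illustrative prototype feature-map size

if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample to prototype resolution
    masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
print(masks.shape)                   # torch.Size([4, 160, 160])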
@@ -7,7 +7,7 @@ import numpy as np
 import torch
 import torch.nn.functional as F

-from ultralytics.yolo.utils import DEFAULT_CFG, NUM_THREADS, ops
+from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, NUM_THREADS, ops
 from ultralytics.yolo.utils.checks import check_requirements
 from ultralytics.yolo.utils.metrics import SegmentMetrics, box_iou, mask_iou
 from ultralytics.yolo.utils.plotting import output_to_target, plot_images
@@ -16,8 +16,8 @@ from ultralytics.yolo.v8.detect import DetectionValidator

 class SegmentationValidator(DetectionValidator):

-    def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None):
-        super().__init__(dataloader, save_dir, pbar, logger, args)
+    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None):
+        super().__init__(dataloader, save_dir, pbar, args)
         self.args.task = 'segment'
         self.metrics = SegmentMetrics(save_dir=self.save_dir)

@@ -120,7 +120,7 @@ class SegmentationValidator(DetectionValidator):
             # save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')

     def finalize_metrics(self, *args, **kwargs):
-        self.metrics.speed = dict(zip(self.metrics.speed.keys(), self.speed))
+        self.metrics.speed = self.speed

     def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False):
         """
@@ -207,7 +207,7 @@ class SegmentationValidator(DetectionValidator):
         if self.args.save_json and self.is_coco and len(self.jdict):
             anno_json = self.data['path'] / 'annotations/instances_val2017.json'  # annotations
             pred_json = self.save_dir / 'predictions.json'  # predictions
-            self.logger.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
+            LOGGER.info(f'\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...')
             try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
                 check_requirements('pycocotools>=2.0.6')
                 from pycocotools.coco import COCO  # noqa
@@ -228,7 +228,7 @@ class SegmentationValidator(DetectionValidator):
                 stats[self.metrics.keys[idx + 1]], stats[
                     self.metrics.keys[idx]] = eval.stats[:2]  # update mAP50-95 and mAP50
             except Exception as e:
-                self.logger.warning(f'pycocotools unable to run: {e}')
+                LOGGER.warning(f'pycocotools unable to run: {e}')
         return stats
