ultralytics 8.0.83 Neptune AI logging addition (#2130)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Snyk bot <snyk-bot@snyk.io>
Co-authored-by: Toutatis64 <Toutatis64@users.noreply.github.com>
Co-authored-by: M. Tolga Cangöz <46008593+standardAI@users.noreply.github.com>
Co-authored-by: Talia Bender <85292283+taliabender@users.noreply.github.com>
Co-authored-by: Ophélie Le Mentec <17216799+ouphi@users.noreply.github.com>
Co-authored-by: Kadir Şahin <68073829+ssahinnkadir@users.noreply.github.com>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>

parent 55a03ad85f
commit 6c082ebd6f
@@ -3,7 +3,7 @@
 # Image is aarch64-compatible for Apple M1 and other ARM architectures i.e. Jetson Nano and Raspberry Pi
 
 # Start FROM Ubuntu image https://hub.docker.com/_/ubuntu
-FROM arm64v8/ubuntu:rolling
+FROM arm64v8/ubuntu:22.10
 
 # Downloads to user config dir
 ADD https://ultralytics.com/assets/Arial.ttf https://ultralytics.com/assets/Arial.Unicode.ttf /root/.config/Ultralytics/
@@ -3,7 +3,7 @@
 # Image is CPU-optimized for ONNX, OpenVINO and PyTorch YOLOv8 deployments
 
 # Start FROM Ubuntu image https://hub.docker.com/_/ubuntu
-FROM ubuntu:rolling
+FROM ubuntu:22.10
 
 # Downloads to user config dir
 ADD https://ultralytics.com/assets/Arial.ttf https://ultralytics.com/assets/Arial.Unicode.ttf /root/.config/Ultralytics/
@@ -56,7 +56,7 @@ whether each source can be used in streaming mode with `stream=True` ✅ and an
 
 
 ## Arguments
-`model.predict` accepts multiple arguments that control the predction operation. These arguments can be passed directly to `model.predict`:
+`model.predict` accepts multiple arguments that control the prediction operation. These arguments can be passed directly to `model.predict`:
 !!! example
     ```
     model.predict(source, save=True, imgsz=320, conf=0.5)
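Beyond the typo fix, a minimal sketch of how the documented call is typically used (the `yolov8n.pt` weights and `bus.jpg` source below are illustrative placeholders):

```python
from ultralytics import YOLO

model = YOLO('yolov8n.pt')  # load a pretrained detection model
results = model.predict('bus.jpg', save=True, imgsz=320, conf=0.5)  # save annotated output
for r in results:
    print(r.boxes)  # boxes detected in this image
```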
@@ -3,11 +3,6 @@
 :::ultralytics.hub.utils.Traces
 <br><br>
 
-# check_dataset_disk_space
----
-:::ultralytics.hub.utils.check_dataset_disk_space
-<br><br>
-
 # request_with_credentials
 ---
 :::ultralytics.hub.utils.request_with_credentials
@@ -8,6 +8,11 @@
 :::ultralytics.yolo.utils.downloads.unzip_file
 <br><br>
 
+# check_disk_space
+---
+:::ultralytics.yolo.utils.downloads.check_disk_space
+<br><br>
+
 # safe_download
 ---
 :::ultralytics.yolo.utils.downloads.safe_download
setup.py
@@ -46,7 +46,7 @@ setup(
         'Intended Audience :: Developers',
         'Intended Audience :: Education',
         'Intended Audience :: Science/Research',
-        'License :: OSI Approved :: GNU Affero General Public License v3 (AGPLv3)',
+        'License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)',
         'Programming Language :: Python :: 3',
         'Programming Language :: Python :: 3.7',
         'Programming Language :: Python :: 3.8',
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = '8.0.82'
+__version__ = '8.0.83'
 
 from ultralytics.hub import start
 from ultralytics.yolo.engine.model import YOLO
@@ -3,6 +3,7 @@
 import glob
 import math
 import os
+import random
 from copy import deepcopy
 from multiprocessing.pool import ThreadPool
 from pathlib import Path
@@ -10,10 +11,11 @@ from typing import Optional
 
 import cv2
 import numpy as np
+import psutil
 from torch.utils.data import Dataset
 from tqdm import tqdm
 
-from ..utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT
+from ..utils import LOCAL_RANK, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT
 from .utils import HELP_URL, IMG_FORMATS
 
 
@@ -63,14 +65,10 @@ class BaseDataset(Dataset):
         self.augment = augment
         self.single_cls = single_cls
         self.prefix = prefix
-
         self.im_files = self.get_img_files(self.img_path)
         self.labels = self.get_labels()
         self.update_labels(include_class=classes)  # single_cls and include_class
-
-        self.ni = len(self.labels)
-
-        # Rect stuff
+        self.ni = len(self.labels)  # number of images
         self.rect = rect
         self.batch_size = batch_size
         self.stride = stride
@@ -80,6 +78,8 @@ class BaseDataset(Dataset):
             self.set_rectangle()
 
         # Cache stuff
+        if cache == 'ram' and not self.check_cache_ram():
+            cache = False
         self.ims = [None] * self.ni
         self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
         if cache:
@@ -148,7 +148,7 @@ class BaseDataset(Dataset):
 
     def cache_images(self, cache):
         """Cache images to memory or disk."""
-        gb = 0  # Gigabytes of cached images
+        b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
         self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni
         fcn = self.cache_images_to_disk if cache == 'disk' else self.load_image
         with ThreadPool(NUM_THREADS) as pool:
@@ -156,11 +156,11 @@ class BaseDataset(Dataset):
             pbar = tqdm(enumerate(results), total=self.ni, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
             for i, x in pbar:
                 if cache == 'disk':
-                    gb += self.npy_files[i].stat().st_size
+                    b += self.npy_files[i].stat().st_size
                 else:  # 'ram'
                     self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
-                    gb += self.ims[i].nbytes
-                pbar.desc = f'{self.prefix}Caching images ({gb / 1E9:.1f}GB {cache})'
+                    b += self.ims[i].nbytes
+                pbar.desc = f'{self.prefix}Caching images ({b / gb:.1f}GB {cache})'
             pbar.close()
 
     def cache_images_to_disk(self, i):
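The renamed accumulator makes the units explicit: `b` counts bytes and `gb = 1 << 30` is the number of bytes per binary gigabyte, so the progress bar now divides by 2^30 rather than the decimal `1E9` used before. A quick sanity check of the two divisors:

```python
b = 3 * (1 << 30)  # suppose 3 GiB of image bytes have been cached
print(f'{b / (1 << 30):.1f}GB')  # binary divisor -> 3.0GB
print(f'{b / 1E9:.1f}GB')        # old decimal divisor -> 3.2GB
```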
@@ -169,6 +169,24 @@ class BaseDataset(Dataset):
         if not f.exists():
             np.save(f.as_posix(), cv2.imread(self.im_files[i]))
 
+    def check_cache_ram(self, safety_margin=0.5):
+        """Check image caching requirements vs available memory."""
+        b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabytes
+        n = min(self.ni, 30)  # extrapolate from 30 random images
+        for _ in range(n):
+            im = cv2.imread(random.choice(self.im_files))  # sample image
+            ratio = self.imgsz / max(im.shape[0], im.shape[1])  # max(h, w)  # ratio
+            b += im.nbytes * ratio ** 2
+        mem_required = b * self.ni / n * (1 + safety_margin)  # GB required to cache dataset into RAM
+        mem = psutil.virtual_memory()
+        cache = mem_required < mem.available  # to cache or not to cache, that is the question
+        if not cache:
+            LOGGER.info(f'{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images '
+                        f'with {int(safety_margin * 100)}% safety margin but only '
+                        f'{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, '
+                        f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
+        return cache
+
     def set_rectangle(self):
         """Sets the shape of bounding boxes for YOLO detections as rectangles."""
         bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int)  # batch index
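The new `check_cache_ram` samples up to 30 images, scales each one's in-memory size by the square of the resize ratio, and extrapolates to the whole dataset before committing to a RAM cache. A standalone sketch of the same estimate, assuming only a list of image paths and a target size (the function name and defaults are illustrative):

```python
import random

import cv2
import psutil


def fits_in_ram(im_files, imgsz=640, safety_margin=0.5, samples=30):
    """Estimate whether resized copies of all images fit in available RAM."""
    b = 0  # accumulated bytes across sampled images
    n = min(len(im_files), samples)
    for f in random.sample(im_files, n):
        im = cv2.imread(f)
        ratio = imgsz / max(im.shape[0], im.shape[1])  # longest-side resize factor
        b += im.nbytes * ratio ** 2  # pixel count scales with ratio squared
    mem_required = b * len(im_files) / n * (1 + safety_margin)
    return mem_required < psutil.virtual_memory().available
```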
@@ -469,31 +469,27 @@ class YOLO:
 
     @property
     def names(self):
-        """
-        Returns class names of the loaded model.
-        """
+        """Returns class names of the loaded model."""
         return self.model.names if hasattr(self.model, 'names') else None
 
     @property
     def device(self):
-        """
-        Returns device if PyTorch model
-        """
+        """Returns device if PyTorch model."""
         return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None
 
     @property
     def transforms(self):
-        """
-        Returns transform of the loaded model.
-        """
+        """Returns transform of the loaded model."""
         return self.model.transforms if hasattr(self.model, 'transforms') else None
 
     def add_callback(self, event: str, func):
-        """
-        Add callback
-        """
+        """Add a callback."""
         self.callbacks[event].append(func)
 
+    def clear_callback(self, event: str):
+        """Clear all event callbacks."""
+        self.callbacks[event] = []
+
     @staticmethod
     def _reset_ckpt_args(args):
         """Reset arguments when loading a PyTorch model."""
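A short sketch of how the new `clear_callback` pairs with the existing `add_callback` (the event name and handler below are illustrative):

```python
from ultralytics import YOLO

def log_epoch(trainer):
    print(f'finished epoch {trainer.epoch}')

model = YOLO('yolov8n.pt')
model.add_callback('on_train_epoch_end', log_epoch)  # register a handler for the event
model.clear_callback('on_train_epoch_end')           # drop every handler for that event
```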
@@ -734,3 +734,26 @@ ENVIRONMENT = 'Colab' if is_colab() else 'Kaggle' if is_kaggle() else 'Jupyter'
     'Docker' if is_docker() else platform.system()
 TESTS_RUNNING = is_pytest_running() or is_github_actions_ci()
 set_sentry()
+
+# OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------
+imshow_ = cv2.imshow  # copy to avoid recursion errors
+
+
+def imread(filename, flags=cv2.IMREAD_COLOR):
+    return cv2.imdecode(np.fromfile(filename, np.uint8), flags)
+
+
+def imwrite(filename, img):
+    try:
+        cv2.imencode(Path(filename).suffix, img)[1].tofile(filename)
+        return True
+    except Exception:
+        return False
+
+
+def imshow(path, im):
+    imshow_(path.encode('unicode_escape').decode(), im)
+
+
+if Path(inspect.stack()[0].filename).parent.parent.as_posix() in inspect.stack()[-1].filename:
+    cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow  # redefine
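These wrappers exist because OpenCV's native `imread`/`imwrite` pass the path through its C++ layer and can fail on non-ASCII filenames on some platforms, while `np.fromfile`/`ndarray.tofile` keep path handling in Python. A minimal sketch of the same round trip (the filename is just an example):

```python
import cv2
import numpy as np

path = '测试图像.jpg'  # non-ASCII filename that plain cv2.imwrite may fail on
img = np.zeros((64, 64, 3), dtype=np.uint8)  # dummy image

cv2.imencode('.jpg', img)[1].tofile(path)  # encode in memory, write via numpy
restored = cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_COLOR)
assert restored is not None and restored.shape == (64, 64, 3)
```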
@@ -200,11 +200,12 @@ def add_integration_callbacks(instance):
     from .comet import callbacks as comet_callbacks
     from .hub import callbacks as hub_callbacks
     from .mlflow import callbacks as mf_callbacks
+    from .neptune import callbacks as neptune_callbacks
     from .raytune import callbacks as tune_callbacks
     from .tensorboard import callbacks as tb_callbacks
     from .wb import callbacks as wb_callbacks
 
-    for x in clearml_callbacks, comet_callbacks, hub_callbacks, tb_callbacks, mf_callbacks, tune_callbacks, wb_callbacks:
+    for x in clearml_callbacks, comet_callbacks, hub_callbacks, tb_callbacks, mf_callbacks, tune_callbacks, wb_callbacks, neptune_callbacks:
         for k, v in x.items():
             if v not in instance.callbacks[k]:  # prevent duplicate callbacks addition
                 instance.callbacks[k].append(v)  # callback[name].append(func)
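Each integration module exposes a `callbacks` dict mapping event names to functions, and `add_integration_callbacks` folds them into the instance while skipping functions that are already registered. A reduced sketch of that merge (the dicts and names are illustrative):

```python
instance_callbacks = {'on_train_end': []}  # stands in for instance.callbacks
neptune_callbacks = {'on_train_end': lambda trainer: None}  # stands in for an integration dict

for x in (neptune_callbacks,):
    for k, v in x.items():
        if v not in instance_callbacks[k]:  # prevent duplicate callbacks addition
            instance_callbacks[k].append(v)
```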
@@ -52,19 +52,12 @@ def on_fit_epoch_end(trainer):
         run.log_metrics(metrics=metrics_dict, step=trainer.epoch)
 
 
-def on_model_save(trainer):
-    """Logs model and metrics to mlflow on save."""
-    if mlflow:
-        run.log_artifact(trainer.last)
-
-
 def on_train_end(trainer):
     """Called at end of train loop to log model artifact info."""
     if mlflow:
         root_dir = Path(__file__).resolve().parents[3]
+        run.log_artifact(trainer.last)
         run.log_artifact(trainer.best)
-        model_uri = f'runs:/{run_id}/'
-        run.register_model(model_uri, experiment_name)
         run.pyfunc.log_model(artifact_path=experiment_name,
                              code_path=[str(root_dir)],
                              artifacts={'model_path': str(trainer.save_dir)},
@@ -74,5 +67,4 @@ def on_train_end(trainer):
 callbacks = {
     'on_pretrain_routine_end': on_pretrain_routine_end,
     'on_fit_epoch_end': on_fit_epoch_end,
-    'on_model_save': on_model_save,
     'on_train_end': on_train_end} if mlflow else {}
ultralytics/yolo/utils/callbacks/neptune.py (new file)
@@ -0,0 +1,105 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+import matplotlib.image as mpimg
+import matplotlib.pyplot as plt
+
+from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING
+from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
+
+try:
+    import neptune
+    from neptune.types import File
+
+    assert not TESTS_RUNNING  # do not log pytest
+    assert hasattr(neptune, '__version__')
+except (ImportError, AssertionError):
+    neptune = None
+
+run = None  # NeptuneAI experiment logger instance
+
+
+def _log_scalars(scalars, step=0):
+    """Log scalars to the NeptuneAI experiment logger."""
+    if run:
+        for k, v in scalars.items():
+            run[k].append(value=v, step=step)
+
+
+def _log_images(imgs_dict, group=''):
+    """Log images to the NeptuneAI experiment logger."""
+    if run:
+        for k, v in imgs_dict.items():
+            run[f'{group}/{k}'].upload(File(v))
+
+
+def _log_plot(title, plot_path):
+    """
+    Log image as plot in the plot section of NeptuneAI.
+
+    arguments:
+    title (str) Title of the plot
+    plot_path (PosixPath or str) Path to the saved image file
+    """
+    img = mpimg.imread(plot_path)
+    fig = plt.figure()
+    ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[])  # no ticks
+    ax.imshow(img)
+    run[f'Plots/{title}'].upload(fig)
+
+
+def on_pretrain_routine_start(trainer):
+    """Callback function called before the training routine starts."""
+    try:
+        global run
+        run = neptune.init_run(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, tags=['YOLOv8'])
+        run['Configuration/Hyperparameters'] = {k: '' if v is None else v for k, v in vars(trainer.args).items()}
+    except Exception as e:
+        LOGGER.warning(f'WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}')
+
+
+def on_train_epoch_end(trainer):
+    """Callback function called at end of each training epoch."""
+    _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
+    _log_scalars(trainer.lr, trainer.epoch + 1)
+    if trainer.epoch == 1:
+        _log_images({f.stem: str(f) for f in trainer.save_dir.glob('train_batch*.jpg')}, 'Mosaic')
+
+
+def on_fit_epoch_end(trainer):
+    """Callback function called at end of each fit (train+val) epoch."""
+    if run and trainer.epoch == 0:
+        model_info = {
+            'parameters': get_num_params(trainer.model),
+            'GFLOPs': round(get_flops(trainer.model), 3),
+            'speed(ms)': round(trainer.validator.speed['inference'], 3)}
+        run['Configuration/Model'] = model_info
+    _log_scalars(trainer.metrics, trainer.epoch + 1)
+
+
+def on_val_end(validator):
+    """Callback function called at end of each validation."""
+    if run:
+        # Log val_labels and val_pred
+        _log_images({f.stem: str(f) for f in validator.save_dir.glob('val*.jpg')}, 'Validation')
+
+
+def on_train_end(trainer):
+    """Callback function called at end of training."""
+    if run:
+        # Log final results, CM matrix + PR plots
+        files = ['results.png', 'confusion_matrix.png', *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
+        files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()]  # filter
+        for f in files:
+            _log_plot(title=f.stem, plot_path=f)
+        # Log the final model
+        run[f'weights/{trainer.args.name or trainer.args.task}/{str(trainer.best.name)}'].upload(File(str(trainer.best)))
+        run.stop()
+
+
+callbacks = {
+    'on_pretrain_routine_start': on_pretrain_routine_start,
+    'on_train_epoch_end': on_train_epoch_end,
+    'on_fit_epoch_end': on_fit_epoch_end,
+    'on_val_end': on_val_end,
+    'on_train_end': on_train_end} if neptune else {}
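With this module in place, Neptune logging switches on automatically whenever the `neptune` package imports cleanly and a run can be initialized; the callbacks above are registered through `add_integration_callbacks`. A hedged sketch of a training call that would exercise them (dataset, weights, and API-token setup are assumed):

```python
# pip install neptune, then export NEPTUNE_API_TOKEN before running
from ultralytics import YOLO

model = YOLO('yolov8n.pt')
model.train(data='coco128.yaml', epochs=3, project='YOLOv8')  # trainer.args.project feeds neptune.init_run
```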