mirror of
https://github.com/THU-MIG/yolov10.git
synced 2025-05-23 21:44:22 +08:00
Set workers=0 for MPS Train and Val modes (#4697)
This commit is contained in:
parent
2bc6e647c7
commit
a4fabfdacf
1
.github/workflows/codeql.yaml
vendored
1
.github/workflows/codeql.yaml
vendored
@ -5,6 +5,7 @@ name: "CodeQL"
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 0 1 * *'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
analyze:
|
||||
|
@ -107,7 +107,7 @@ class BaseTrainer:
|
||||
print_args(vars(self.args))
|
||||
|
||||
# Device
|
||||
if self.device.type == 'cpu':
|
||||
if self.device.type in ('cpu', 'mps'):
|
||||
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
|
||||
|
||||
# Model and Dataset
|
||||
|
@ -144,7 +144,7 @@ class BaseValidator:
|
||||
else:
|
||||
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
|
||||
|
||||
if self.device.type == 'cpu':
|
||||
if self.device.type in ('cpu', 'mps'):
|
||||
self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
|
||||
if not pt:
|
||||
self.args.rect = False
|
||||
|
@ -196,17 +196,27 @@ def add_integration_callbacks(instance):
|
||||
instance (Trainer, Predictor, Validator, Exporter): An object with a 'callbacks' attribute that is a dictionary
|
||||
of callback lists.
|
||||
"""
|
||||
from .clearml import callbacks as clearml_cb
|
||||
|
||||
# Load HUB callbacks
|
||||
from .hub import callbacks
|
||||
|
||||
# Load training callbacks
|
||||
if 'Trainer' in instance.__class__.__name__:
|
||||
from .clearml import callbacks as clear_cb
|
||||
from .comet import callbacks as comet_cb
|
||||
from .dvc import callbacks as dvc_cb
|
||||
from .hub import callbacks as hub_cb
|
||||
from .mlflow import callbacks as mlflow_cb
|
||||
from .neptune import callbacks as neptune_cb
|
||||
from .raytune import callbacks as tune_cb
|
||||
from .tensorboard import callbacks as tensorboard_cb
|
||||
from .tensorboard import callbacks as tb_cb
|
||||
from .wb import callbacks as wb_cb
|
||||
callbacks.update({**clear_cb, **comet_cb, **dvc_cb, **mlflow_cb, **neptune_cb, **tune_cb, **tb_cb, **wb_cb})
|
||||
|
||||
for x in clearml_cb, comet_cb, hub_cb, mlflow_cb, neptune_cb, tune_cb, tensorboard_cb, wb_cb, dvc_cb:
|
||||
for k, v in x.items():
|
||||
# Load export callbacks (patch to avoid CoreML protobuf error)
|
||||
if 'Exporter' in instance.__class__.__name__:
|
||||
from .tensorboard import callbacks as tb_cb
|
||||
callbacks.update(tb_cb)
|
||||
|
||||
for k, v in callbacks.items():
|
||||
if v not in instance.callbacks[k]: # prevent duplicate callbacks addition
|
||||
instance.callbacks[k].append(v) # callback[name].append(func)
|
||||
|
@ -1,12 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import re
|
||||
|
||||
import matplotlib.image as mpimg
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
|
||||
from ultralytics.utils.torch_utils import model_info_for_loggers
|
||||
|
||||
try:
|
||||
assert not TESTS_RUNNING # do not log pytest
|
||||
@ -15,8 +9,8 @@ try:
|
||||
from clearml import Task
|
||||
from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
|
||||
from clearml.binding.matplotlib_bind import PatchedMatplotlib
|
||||
|
||||
assert hasattr(clearml, '__version__') # verify package is not directory
|
||||
|
||||
except (ImportError, AssertionError):
|
||||
clearml = None
|
||||
|
||||
@ -29,6 +23,8 @@ def _log_debug_samples(files, title='Debug Samples') -> None:
|
||||
files (list): A list of file paths in PosixPath format.
|
||||
title (str): A title that groups together images with the same values.
|
||||
"""
|
||||
import re
|
||||
|
||||
if task := Task.current_task():
|
||||
for f in files:
|
||||
if f.exists():
|
||||
@ -48,6 +44,9 @@ def _log_plot(title, plot_path) -> None:
|
||||
title (str): The title of the plot.
|
||||
plot_path (str): The path to the saved image file.
|
||||
"""
|
||||
import matplotlib.image as mpimg
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
img = mpimg.imread(plot_path)
|
||||
fig = plt.figure()
|
||||
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks
|
||||
@ -103,6 +102,7 @@ def on_fit_epoch_end(trainer):
|
||||
value=trainer.epoch_time,
|
||||
iteration=trainer.epoch)
|
||||
if trainer.epoch == 0:
|
||||
from ultralytics.utils.torch_utils import model_info_for_loggers
|
||||
for k, v in model_info_for_loggers(trainer).items():
|
||||
task.get_logger().report_single_value(k, v)
|
||||
|
||||
|
@ -1,10 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops
|
||||
from ultralytics.utils.torch_utils import model_info_for_loggers
|
||||
|
||||
try:
|
||||
assert not TESTS_RUNNING # do not log pytest
|
||||
@ -12,8 +8,9 @@ try:
|
||||
import comet_ml
|
||||
|
||||
assert hasattr(comet_ml, '__version__') # verify package is not directory
|
||||
except (ImportError, AssertionError):
|
||||
comet_ml = None
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Ensures certain logging functions only run for supported tasks
|
||||
COMET_SUPPORTED_TASKS = ['detect']
|
||||
@ -24,6 +21,9 @@ LABEL_PLOT_NAMES = 'labels', 'labels_correlogram'
|
||||
|
||||
_comet_image_prediction_count = 0
|
||||
|
||||
except (ImportError, AssertionError):
|
||||
comet_ml = None
|
||||
|
||||
|
||||
def _get_comet_mode():
|
||||
return os.getenv('COMET_MODE', 'online')
|
||||
@ -327,6 +327,7 @@ def on_fit_epoch_end(trainer):
|
||||
experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch)
|
||||
experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch)
|
||||
if curr_epoch == 1:
|
||||
from ultralytics.utils.torch_utils import model_info_for_loggers
|
||||
experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch)
|
||||
|
||||
if not save_assets:
|
||||
|
@ -1,37 +1,37 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import pkg_resources as pkg
|
||||
|
||||
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
|
||||
from ultralytics.utils.torch_utils import model_info_for_loggers
|
||||
|
||||
try:
|
||||
assert not TESTS_RUNNING # do not log pytest
|
||||
assert SETTINGS['dvc'] is True # verify integration is enabled
|
||||
from importlib.metadata import version
|
||||
|
||||
import dvclive
|
||||
|
||||
assert hasattr(dvclive, '__version__') # verify package is not directory
|
||||
|
||||
import os
|
||||
import re
|
||||
from importlib.metadata import version
|
||||
from pathlib import Path
|
||||
|
||||
import pkg_resources as pkg
|
||||
|
||||
ver = version('dvclive')
|
||||
if pkg.parse_version(ver) < pkg.parse_version('2.11.0'):
|
||||
LOGGER.debug(f'DVCLive is detected but version {ver} is incompatible (>=2.11 required).')
|
||||
dvclive = None # noqa: F811
|
||||
except (ImportError, AssertionError, TypeError):
|
||||
dvclive = None
|
||||
|
||||
# DVCLive logger instance
|
||||
live = None
|
||||
_processed_plots = {}
|
||||
|
||||
# `on_fit_epoch_end` is called on final validation (probably need to be fixed)
|
||||
# for now this is the way we distinguish final evaluation of the best model vs
|
||||
# last epoch validation
|
||||
# `on_fit_epoch_end` is called on final validation (probably needs to be fixed); for now this is the way we
|
||||
# distinguish final evaluation of the best model vs last epoch validation
|
||||
_training_epoch = False
|
||||
|
||||
except (ImportError, AssertionError, TypeError):
|
||||
dvclive = None
|
||||
|
||||
|
||||
def _log_images(path, prefix=''):
|
||||
if live:
|
||||
@ -103,6 +103,7 @@ def on_fit_epoch_end(trainer):
|
||||
live.log_metric(metric, value)
|
||||
|
||||
if trainer.epoch == 0:
|
||||
from ultralytics.utils.torch_utils import model_info_for_loggers
|
||||
for metric, value in model_info_for_loggers(trainer).items():
|
||||
live.log_metric(metric, value, plot=False)
|
||||
|
||||
|
@ -5,7 +5,6 @@ from time import time
|
||||
|
||||
from ultralytics.hub.utils import HUB_WEB_ROOT, PREFIX, events
|
||||
from ultralytics.utils import LOGGER, SETTINGS
|
||||
from ultralytics.utils.torch_utils import model_info_for_loggers
|
||||
|
||||
|
||||
def on_pretrain_routine_end(trainer):
|
||||
@ -24,6 +23,7 @@ def on_fit_epoch_end(trainer):
|
||||
# Upload metrics after val end
|
||||
all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics}
|
||||
if trainer.epoch == 0:
|
||||
from ultralytics.utils.torch_utils import model_info_for_loggers
|
||||
all_plots = {**all_plots, **model_info_for_loggers(trainer)}
|
||||
session.metrics_queue[trainer.epoch] = json.dumps(all_plots)
|
||||
if time() - session.timers['metrics'] > session.rate_limits['metrics']:
|
||||
|
@ -1,17 +1,16 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr
|
||||
from ultralytics.utils import LOGGER, ROOT, SETTINGS, TESTS_RUNNING, colorstr
|
||||
|
||||
try:
|
||||
assert not TESTS_RUNNING # do not log pytest
|
||||
assert SETTINGS['mlflow'] is True # verify integration is enabled
|
||||
import mlflow
|
||||
|
||||
assert hasattr(mlflow, '__version__') # verify package is not directory
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
except (ImportError, AssertionError):
|
||||
mlflow = None
|
||||
|
||||
@ -56,11 +55,10 @@ def on_fit_epoch_end(trainer):
|
||||
def on_train_end(trainer):
|
||||
"""Called at end of train loop to log model artifact info."""
|
||||
if mlflow:
|
||||
root_dir = Path(__file__).resolve().parents[3]
|
||||
run.log_artifact(trainer.last)
|
||||
run.log_artifact(trainer.best)
|
||||
run.pyfunc.log_model(artifact_path=experiment_name,
|
||||
code_path=[str(root_dir)],
|
||||
code_path=[str(ROOT.parent)],
|
||||
artifacts={'model_path': str(trainer.save_dir)},
|
||||
python_model=run.pyfunc.PythonModel())
|
||||
|
||||
|
@ -1,10 +1,6 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import matplotlib.image as mpimg
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
|
||||
from ultralytics.utils.torch_utils import model_info_for_loggers
|
||||
|
||||
try:
|
||||
assert not TESTS_RUNNING # do not log pytest
|
||||
@ -13,11 +9,12 @@ try:
|
||||
from neptune.types import File
|
||||
|
||||
assert hasattr(neptune, '__version__')
|
||||
except (ImportError, AssertionError):
|
||||
neptune = None
|
||||
|
||||
run = None # NeptuneAI experiment logger instance
|
||||
|
||||
except (ImportError, AssertionError):
|
||||
neptune = None
|
||||
|
||||
|
||||
def _log_scalars(scalars, step=0):
|
||||
"""Log scalars to the NeptuneAI experiment logger."""
|
||||
@ -42,6 +39,9 @@ def _log_plot(title, plot_path):
|
||||
title (str) Title of the plot
|
||||
plot_path (PosixPath or str) Path to the saved image file
|
||||
"""
|
||||
import matplotlib.image as mpimg
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
img = mpimg.imread(plot_path)
|
||||
fig = plt.figure()
|
||||
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks
|
||||
@ -70,6 +70,7 @@ def on_train_epoch_end(trainer):
|
||||
def on_fit_epoch_end(trainer):
|
||||
"""Callback function called at end of each fit (train+val) epoch."""
|
||||
if run and trainer.epoch == 0:
|
||||
from ultralytics.utils.torch_utils import model_info_for_loggers
|
||||
run['Configuration/Model'] = model_info_for_loggers(trainer)
|
||||
_log_scalars(trainer.metrics, trainer.epoch + 1)
|
||||
|
||||
|
@ -7,6 +7,7 @@ try:
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.air import session
|
||||
|
||||
except (ImportError, AssertionError):
|
||||
tune = None
|
||||
|
||||
|
@ -8,12 +8,12 @@ try:
|
||||
|
||||
assert not TESTS_RUNNING # do not log pytest
|
||||
assert SETTINGS['tensorboard'] is True # verify integration is enabled
|
||||
WRITER = None # TensorBoard SummaryWriter instance
|
||||
|
||||
except (ImportError, AssertionError, TypeError):
|
||||
# TypeError for handling 'Descriptors cannot not be created directly.' protobuf errors in Windows
|
||||
SummaryWriter = None
|
||||
|
||||
WRITER = None # TensorBoard SummaryWriter instance
|
||||
|
||||
|
||||
def _log_scalars(scalars, step=0):
|
||||
"""Logs scalar values to TensorBoard."""
|
||||
@ -72,4 +72,4 @@ callbacks = {
|
||||
'on_pretrain_routine_start': on_pretrain_routine_start,
|
||||
'on_train_start': on_train_start,
|
||||
'on_fit_epoch_end': on_fit_epoch_end,
|
||||
'on_batch_end': on_batch_end}
|
||||
'on_batch_end': on_batch_end} if SummaryWriter else {}
|
||||
|
@ -9,11 +9,12 @@ try:
|
||||
import wandb as wb
|
||||
|
||||
assert hasattr(wb, '__version__')
|
||||
except (ImportError, AssertionError):
|
||||
wb = None
|
||||
|
||||
_processed_plots = {}
|
||||
|
||||
except (ImportError, AssertionError):
|
||||
wb = None
|
||||
|
||||
|
||||
def _log_plots(plots, step):
|
||||
for name, params in plots.items():
|
||||
|
@ -273,7 +273,7 @@ def safe_download(url,
|
||||
"""
|
||||
|
||||
# Check if the URL is a Google Drive link
|
||||
gdrive = 'drive.google.com' in url
|
||||
gdrive = url.startswith('https://drive.google.com/')
|
||||
if gdrive:
|
||||
url, file = get_google_drive_file_info(url)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user